Commit 195e227c authored by song's avatar song
Browse files

merge: 合并 upstream/main 并保留本地图片计费功能

parents 6fa704d6 752882a0
...@@ -123,6 +123,7 @@ type UpdateGroupInput struct { ...@@ -123,6 +123,7 @@ type UpdateGroupInput struct {
type CreateAccountInput struct { type CreateAccountInput struct {
Name string Name string
Notes *string
Platform string Platform string
Type string Type string
Credentials map[string]any Credentials map[string]any
...@@ -138,6 +139,7 @@ type CreateAccountInput struct { ...@@ -138,6 +139,7 @@ type CreateAccountInput struct {
type UpdateAccountInput struct { type UpdateAccountInput struct {
Name string Name string
Notes *string
Type string // Account type: oauth, setup-token, apikey Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
...@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
Platform: input.Platform, Platform: input.Platform,
Type: input.Type, Type: input.Type,
Credentials: input.Credentials, Credentials: input.Credentials,
...@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Type != "" { if input.Type != "" {
account.Type = input.Type account.Type = input.Type
} }
if input.Notes != nil {
account.Notes = normalizeAccountNotes(input.Notes)
}
if len(input.Credentials) > 0 { if len(input.Credentials) > 0 {
account.Credentials = input.Credentials account.Credentials = input.Credentials
} }
...@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Extra = input.Extra account.Extra = input.Extra
} }
if input.ProxyID != nil { if input.ProxyID != nil {
account.ProxyID = input.ProxyID // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *input.ProxyID == 0 {
account.ProxyID = nil
} else {
account.ProxyID = input.ProxyID
}
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
} }
// 只在指针非 nil 时更新 Concurrency(支持设置为 0) // 只在指针非 nil 时更新 Concurrency(支持设置为 0)
......
package service
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[
{"type":"thinking","thinking":"secret plan","signature":""},
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]`),
},
{
Role: "user",
Content: json.RawMessage(`[
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
{"type":"redacted_thinking","data":"..."}
]`),
},
},
}
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
require.Len(t, req.Messages, 2)
var blocks0 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
require.Len(t, blocks0, 2)
require.Equal(t, "text", blocks0[0]["type"])
require.Equal(t, "secret plan", blocks0[0]["text"])
require.Equal(t, "text", blocks0[1]["type"])
var blocks1 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
require.Len(t, blocks1, 1)
require.Equal(t, "text", blocks1[0]["type"])
require.NotEmpty(t, blocks1[0]["text"])
}
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
},
},
}
changed, err := stripThinkingFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
var blocks []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
require.Len(t, blocks, 2)
require.Equal(t, "text", blocks[0]["type"])
require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"])
}
...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S ...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token // VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil { if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证 return nil // 服务未配置则跳过验证
} }
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP) return s.turnstileService.VerifyToken(ctx, token, remoteIP)
} }
......
...@@ -16,7 +16,8 @@ import ( ...@@ -16,7 +16,8 @@ import (
// 注:ErrInsufficientBalance在redeem_service.go中定义 // 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var ( var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
...@@ -72,10 +73,11 @@ type cacheWriteTask struct { ...@@ -72,10 +73,11 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务 // BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
cache BillingCache cache BillingCache
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup cacheWriteWg sync.WaitGroup
...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo ...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo, subRepo: subRepo,
cfg: cfg, cfg: cfg,
} }
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers() svc.startCacheWriteWorkers()
return svc return svc
} }
...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple { if s.cfg.RunMode == config.RunModeSimple {
return nil return nil
} }
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式 // 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error { func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID) balance, err := s.GetUserBalance(ctx, userID)
if err != nil { if err != nil {
// 缓存/数据库错误,允许通过(降级处理) if s.circuitBreaker != nil {
log.Printf("Warning: get user balance failed, allowing request: %v", err) s.circuitBreaker.OnFailure(err)
return nil }
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
if balance <= 0 { if balance <= 0 {
...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据 // 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil { if err != nil {
// 缓存/数据库错误,降级使用传入的subscription进行检查 if s.circuitBreaker != nil {
log.Printf("Warning: get subscription cache failed, using fallback: %v", err) s.circuitBreaker.OnFailure(err)
return s.checkSubscriptionLimitsFallback(subscription, group) }
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
// 检查订阅状态 // 检查订阅状态
...@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil return nil
} }
// checkSubscriptionLimitsFallback 降级检查订阅限额 type billingCircuitBreakerState int
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil { const (
return ErrSubscriptionInvalid billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
} }
}
if !subscription.IsActive() { func (b *billingCircuitBreaker) Allow() bool {
return ErrSubscriptionInvalid b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
} }
}
if !subscription.CheckDailyLimit(group, 0) { func (b *billingCircuitBreaker) OnFailure(err error) {
return ErrDailyLimitExceeded if b == nil {
return
} }
b.mu.Lock()
defer b.mu.Unlock()
if !subscription.CheckWeeklyLimit(group, 0) { switch b.state {
return ErrWeeklyLimitExceeded case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
} }
}
if !subscription.CheckMonthlyLimit(group, 0) { func (b *billingCircuitBreaker) OnSuccess() {
return ErrMonthlyLimitExceeded if b == nil {
return
} }
b.mu.Lock()
defer b.mu.Unlock()
return nil previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
} }
...@@ -8,12 +8,13 @@ import ( ...@@ -8,12 +8,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
type CRSSyncService struct { type CRSSyncService struct {
...@@ -22,6 +23,7 @@ type CRSSyncService struct { ...@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService geminiOAuthService *GeminiOAuthService
cfg *config.Config
} }
func NewCRSSyncService( func NewCRSSyncService(
...@@ -30,6 +32,7 @@ func NewCRSSyncService( ...@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService { ) *CRSSyncService {
return &CRSSyncService{ return &CRSSyncService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -37,6 +40,7 @@ func NewCRSSyncService( ...@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService, oauthService: oauthService,
openaiOAuthService: openaiOAuthService, openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService, geminiOAuthService: geminiOAuthService,
cfg: cfg,
} }
} }
...@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct { ...@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct {
} }
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL) if s.cfg == nil {
if err != nil { return nil, errors.New("config is not available")
return nil, err }
baseURL := strings.TrimSpace(input.BaseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
baseURL = normalized
} else {
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
baseURL = normalized
} }
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" { if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required") return nil, errors.New("username and password are required")
} }
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second, Timeout: 20 * time.Second,
ValidateResolvedIP: s.cfg.Security.URLAllowlist.Enabled,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 20 * time.Second} client = &http.Client{Timeout: 20 * time.Second}
...@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string { ...@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active" return "active"
} }
func normalizeBaseURL(raw string) (string, error) { func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
trimmed := strings.TrimSpace(raw) // 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
if trimmed == "" { requireAllowlist := len(allowlist) > 0
return "", errors.New("base_url is required") normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
} AllowedHosts: allowlist,
u, err := url.Parse(trimmed) RequireAllowlist: requireAllowlist,
if err != nil || u.Scheme == "" || u.Host == "" { AllowPrivate: allowPrivate,
return "", fmt.Errorf("invalid base_url: %s", trimmed) })
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
} }
u.Path = strings.TrimRight(u.Path, "/") return normalized, nil
return strings.TrimRight(u.String(), "/"), nil
} }
// cleanBaseURL removes trailing suffix from base_url in credentials // cleanBaseURL removes trailing suffix from base_url in credentials
......
...@@ -101,6 +101,10 @@ const ( ...@@ -101,6 +101,10 @@ const (
SettingKeyFallbackModelOpenAI = "fallback_model_openai" SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini" SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity" SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
......
This diff is collapsed.
...@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) { ...@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) {
}) })
} }
} }
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[
{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
{"type":"text","text":"Answer"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
require.Len(t, msgs, 2)
assistant, ok := msgs[1].(map[string]any)
require.True(t, ok)
content, ok := assistant["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
first, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", first["type"])
require.Equal(t, "Let me think...", first["text"])
}
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
}
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"redacted_thinking","data":"..."},
{"type":"text","text":"Visible"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "Visible", content0["text"])
}
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled"},
"messages":[
{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.NotEmpty(t, content0["text"])
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
]}
]
}`)
out := FilterSignatureSensitiveBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
content1, ok := content[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "text", content1["type"])
require.Contains(t, content0["text"], "tool_use")
require.Contains(t, content1["text"], "tool_result")
}
This diff is collapsed.
...@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ...@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL), ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 30 * time.Second} client = &http.Client{Timeout: 30 * time.Second}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -4,17 +4,19 @@ type SystemSettings struct { ...@@ -4,17 +4,19 @@ type SystemSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
SMTPHost string SMTPHost string
SMTPPort int SMTPPort int
SMTPUsername string SMTPUsername string
SMTPPassword string SMTPPassword string
SMTPFrom string SMTPPasswordConfigured bool
SMTPFromName string SMTPFrom string
SMTPUseTLS bool SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string TurnstileEnabled bool
TurnstileSecretKey string TurnstileSiteKey string
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
SiteName string SiteName string
SiteLogo string SiteLogo string
...@@ -32,6 +34,10 @@ type SystemSettings struct { ...@@ -32,6 +34,10 @@ type SystemSettings struct {
FallbackModelOpenAI string `json:"fallback_model_openai"` FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"` FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"` FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
} }
type PublicSettings struct { type PublicSettings struct {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment