Commit a161fcc8 authored by cyhhao's avatar cyhhao
Browse files

Merge branch 'main' of github.com:Wei-Shaw/sub2api

parents 65e69738 e32c5f53
...@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() {
if selection.Acquired && selection.ReleaseFunc != nil { interceptType := detectInterceptType(body)
selection.ReleaseFunc() if interceptType != InterceptTypeNone {
} if selection.Acquired && selection.ReleaseFunc != nil {
if reqStream { selection.ReleaseFunc()
sendMockWarmupStream(c, reqModel) }
} else { if reqStream {
sendMockWarmupResponse(c, reqModel) sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
} }
return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
...@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() {
if selection.Acquired && selection.ReleaseFunc != nil { interceptType := detectInterceptType(body)
selection.ReleaseFunc() if interceptType != InterceptTypeNone {
} if selection.Acquired && selection.ReleaseFunc != nil {
if reqStream { selection.ReleaseFunc()
sendMockWarmupStream(c, reqModel) }
} else { if reqStream {
sendMockWarmupResponse(c, reqModel) sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
} }
return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
...@@ -768,17 +774,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -768,17 +774,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
} }
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等) // InterceptType 表示请求拦截类型
func isWarmupRequest(body []byte) bool { type InterceptType int
// 快速检查:如果body不包含关键字,直接返回false
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
)
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body) bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") { hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
return false hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
if !hasSuggestionMode && !hasWarmupKeyword {
return InterceptTypeNone
} }
// 解析完整请求 // 解析请求(只解析一次)
var req struct { var req struct {
Messages []struct { Messages []struct {
Role string `json:"role"`
Content []struct { Content []struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`
...@@ -789,43 +808,71 @@ func isWarmupRequest(body []byte) bool { ...@@ -789,43 +808,71 @@ func isWarmupRequest(body []byte) bool {
} `json:"system"` } `json:"system"`
} }
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
return false return InterceptTypeNone
} }
// 检查 messages 中的标题提示模式 // 检查 SUGGESTION MODE(最后一条 user 消息)
for _, msg := range req.Messages { if hasSuggestionMode && len(req.Messages) > 0 {
for _, content := range msg.Content { lastMsg := req.Messages[len(req.Messages)-1]
if content.Type == "text" { if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") || lastMsg.Content[0].Type == "text" &&
content.Text == "Warmup" { strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
return true return InterceptTypeSuggestionMode
}
}
} }
} }
// 检查 system 中的标题提取模式 // 检查 Warmup 请求
for _, system := range req.System { if hasWarmupKeyword {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") { // 检查 messages 中的标题提示模式
return true for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return InterceptTypeWarmup
}
}
}
}
// 检查 system 中的标题提取模式
for _, sys := range req.System {
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return InterceptTypeWarmup
}
} }
} }
return false return InterceptTypeNone
} }
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截) // sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) { func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no") c.Header("X-Accel-Buffering", "no")
// 根据拦截类型决定响应内容
var msgID string
var outputTokens int
var textDeltas []string
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
outputTokens = 1
textDeltas = []string{""} // 空内容
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
outputTokens = 2
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling // Build message_start event with proper JSON marshaling
messageStart := map[string]any{ messageStart := map[string]any{
"type": "message_start", "type": "message_start",
"message": map[string]any{ "message": map[string]any{
"id": "msg_mock_warmup", "id": msgID,
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": model, "model": model,
...@@ -840,16 +887,46 @@ func sendMockWarmupStream(c *gin.Context, model string) { ...@@ -840,16 +887,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
} }
messageStartJSON, _ := json.Marshal(messageStart) messageStartJSON, _ := json.Marshal(messageStart)
// Build events
events := []string{ events := []string{
`event: message_start` + "\n" + `data: ` + string(messageStartJSON), `event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`, `event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
} }
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
)
for _, event := range events { for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n") _, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush() c.Writer.Flush()
...@@ -857,18 +934,32 @@ func sendMockWarmupStream(c *gin.Context, model string) { ...@@ -857,18 +934,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
} }
} }
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截) // sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) { func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var outputTokens int
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup", "id": msgID,
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": model, "model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}}, "content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn", "stop_reason": "end_turn",
"usage": gin.H{ "usage": gin.H{
"input_tokens": 10, "input_tokens": 10,
"output_tokens": 2, "output_tokens": outputTokens,
}, },
}) })
} }
......
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}
package handler package handler
import ( import (
"bytes"
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"io" "io"
"log" "log"
"net/http" "net/http"
"regexp"
"strings" "strings"
"time" "time"
...@@ -19,6 +23,17 @@ import ( ...@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies: // GeminiV1BetaListModels proxies:
// GET /v1beta/models // GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
...@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body) // 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash sessionKey := sessionHash
if sessionHash != "" { if sessionHash != "" {
sessionKey = "gemini:" + sessionHash sessionKey = "gemini:" + sessionHash
} }
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
...@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired { if !selection.Acquired {
...@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { ...@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
} }
return false return false
} }
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
// 2. 从 header 中提取 privileged-user-id(UUID)
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符:privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
return tmpDirHash
}
...@@ -37,6 +37,7 @@ type Handlers struct { ...@@ -37,6 +37,7 @@ type Handlers struct {
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler Setting *SettingHandler
Totp *TotpHandler
} }
// BuildInfo contains build-time information // BuildInfo contains build-time information
......
...@@ -32,20 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -32,20 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
} }
response.Success(c, dto.PublicSettings{ response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileEnabled: settings.TurnstileEnabled,
SiteName: settings.SiteName, TurnstileSiteKey: settings.TurnstileSiteKey,
SiteLogo: settings.SiteLogo, SiteName: settings.SiteName,
SiteSubtitle: settings.SiteSubtitle, SiteLogo: settings.SiteLogo,
APIBaseURL: settings.APIBaseURL, SiteSubtitle: settings.SiteSubtitle,
ContactInfo: settings.ContactInfo, APIBaseURL: settings.APIBaseURL,
DocURL: settings.DocURL, ContactInfo: settings.ContactInfo,
HomeContent: settings.HomeContent, DocURL: settings.DocURL,
HideCcsImportButton: settings.HideCcsImportButton, HomeContent: settings.HomeContent,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, HideCcsImportButton: settings.HideCcsImportButton,
Version: h.version, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
}) })
} }
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
...@@ -70,6 +70,7 @@ func ProvideHandlers( ...@@ -70,6 +70,7 @@ func ProvideHandlers(
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
Auth: authHandler, Auth: authHandler,
...@@ -82,6 +83,7 @@ func ProvideHandlers( ...@@ -82,6 +83,7 @@ func ProvideHandlers(
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
Totp: totpHandler,
} }
} }
...@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet( ...@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler, NewSubscriptionHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler, ProvideSettingHandler,
// Admin handlers // Admin handlers
......
...@@ -7,13 +7,11 @@ import ( ...@@ -7,13 +7,11 @@ import (
"fmt" "fmt"
"log" "log"
"math/rand" "math/rand"
"os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
...@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking, Text: block.Thinking,
Thought: true, Thought: true,
} }
// 保留原有 signature(Claude 模型需要有效的 signature) // signature 处理:
if block.Signature != "" { // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if !allowDummyThought { } else if !allowDummyThought {
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。 // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
...@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}, },
} }
// tool_use 的 signature 处理: // tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验) // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路) // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if allowDummyThought { if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
...@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
} }
// 清理 JSON Schema // 清理 JSON Schema
params := cleanJSONSchema(inputSchema) // 1. 深度清理 [undefined] 值
DeepCleanUndefined(inputSchema)
// 2. 转换为符合 Gemini v1internal 的 schema
params := CleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值 // 为 nil schema 提供默认值
if params == nil { if params == nil {
params = map[string]any{ params = map[string]any{
"type": "OBJECT", "type": "object", // lowercase type
"properties": map[string]any{}, "properties": map[string]any{},
} }
} }
...@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
FunctionDeclarations: funcDecls, FunctionDeclarations: funcDecls,
}} }}
} }
// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
// 确保有 type 字段(默认 OBJECT)
if _, hasType := result["type"]; !hasType {
result["type"] = "OBJECT"
}
// 确保有 properties 字段(默认空对象)
if _, hasProps := result["properties"]; !hasProps {
result["properties"] = make(map[string]any)
}
// 验证 required 中的字段都存在于 properties 中
if required, ok := result["required"].([]any); ok {
if props, ok := result["properties"].(map[string]any); ok {
validRequired := make([]any, 0, len(required))
for _, r := range required {
if reqName, ok := r.(string); ok {
if _, exists := props[reqName]; exists {
validRequired = append(validRequired, r)
}
}
}
if len(validRequired) > 0 {
result["required"] = validRequired
} else {
delete(result, "required")
}
}
}
return result
}
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出(debug/test)
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证(Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schema(Gemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any, path string) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue
}
// 特殊处理 type 字段
if k == "type" {
result[k] = cleanTypeValue(val)
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
} else {
// 如果是 schema 对象,转换为 false(更安全的默认值)
result[k] = false
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val, path+"."+k)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
}
return cleaned
default:
return value
}
}
// cleanTypeValue 处理 type 字段,转换为大写
func cleanTypeValue(value any) any {
switch v := value.(type) {
case string:
return strings.ToUpper(v)
case []any:
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
for _, t := range v {
if ts, ok := t.(string); ok && ts != "null" {
return strings.ToUpper(ts)
}
}
// 如果只有 null,返回 STRING
return "STRING"
default:
return value
}
}
...@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"} {"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]` ]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) { t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true) parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil { if err != nil {
...@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil { if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts) t.Fatalf("expected 1 functionCall part, got %+v", parts)
} }
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature { if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature) t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
} }
......
...@@ -3,6 +3,7 @@ package antigravity ...@@ -3,6 +3,7 @@ package antigravity
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"strings" "strings"
) )
...@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, * ...@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp.Response = directResp v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion v1Resp.ModelVersion = directResp.ModelVersion
} else if len(v1Resp.Response.Candidates) == 0 {
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
} }
// 使用处理器转换 // 使用处理器转换
...@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) { ...@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.trailingSignature = "" p.trailingSignature = ""
} }
p.textBuilder += part.Text // 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
if signature != "" { if signature != "" {
p.flushText() p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: part.Text,
})
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{ p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking", Type: "thinking",
Thinking: "", Thinking: "",
Signature: signature, Signature: signature,
}) })
} else {
// 普通 text (无签名) - 累积到 builder
p.textBuilder += part.Text
} }
} }
} }
...@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon ...@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
var finishReason string var finishReason string
if len(geminiResp.Candidates) > 0 { if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason finishReason = geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
} }
stopReason := "end_turn" stopReason := "end_turn"
......
package antigravity
import (
"fmt"
"strings"
)
// CleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
func CleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
// 0. 预处理:展开 $ref (Schema Flattening)
// (Go map 是引用的,直接修改 schema)
flattenRefs(schema, extractDefs(schema))
// 递归清理
cleaned := cleanJSONSchemaRecursive(schema)
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
return result
}
// extractDefs 提取并移除定义的 helper
func extractDefs(schema map[string]any) map[string]any {
defs := make(map[string]any)
if d, ok := schema["$defs"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "$defs")
}
if d, ok := schema["definitions"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "definitions")
}
return defs
}
// flattenRefs 递归展开 $ref
func flattenRefs(schema map[string]any, defs map[string]any) {
if len(defs) == 0 {
return // 无需展开
}
// 检查并替换 $ref
if ref, ok := schema["$ref"].(string); ok {
delete(schema, "$ref")
// 解析引用名 (例如 #/$defs/MyType -> MyType)
parts := strings.Split(ref, "/")
refName := parts[len(parts)-1]
if defSchema, exists := defs[refName]; exists {
if defMap, ok := defSchema.(map[string]any); ok {
// 合并定义内容 (不覆盖现有 key)
for k, v := range defMap {
if _, has := schema[k]; !has {
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
}
}
// 递归处理刚刚合并进来的内容
flattenRefs(schema, defs)
}
}
}
// 遍历子节点
for _, v := range schema {
if subMap, ok := v.(map[string]any); ok {
flattenRefs(subMap, defs)
} else if subArr, ok := v.([]any); ok {
for _, item := range subArr {
if itemMap, ok := item.(map[string]any); ok {
flattenRefs(itemMap, defs)
}
}
}
}
}
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
func deepCopy(src any) any {
if src == nil {
return nil
}
switch v := src.(type) {
case map[string]any:
dst := make(map[string]any)
for k, val := range v {
dst[k] = deepCopy(val)
}
return dst
case []any:
dst := make([]any, len(v))
for i, val := range v {
dst[i] = deepCopy(val)
}
return dst
default:
return src
}
}
// cleanJSONSchemaRecursive 递归核心清理逻辑
// 返回处理后的值 (通常是 input map,但可能修改内部结构)
func cleanJSONSchemaRecursive(value any) any {
schemaMap, ok := value.(map[string]any)
if !ok {
return value
}
// 0. [NEW] 合并 allOf
mergeAllOf(schemaMap)
// 1. [CRITICAL] 深度递归处理子项
if props, ok := schemaMap["properties"].(map[string]any); ok {
for _, v := range props {
cleanJSONSchemaRecursive(v)
}
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required,
// 因为我们在子项处理中会正确设置 type 和 description
} else if items, ok := schemaMap["items"]; ok {
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
if itemsArr, ok := items.([]any); ok {
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
best := extractBestSchemaFromUnion(itemsArr)
if best == nil {
// 回退到通用字符串
best = map[string]any{"type": "string"}
}
// 用处理后的对象替换原有数组
cleanedBest := cleanJSONSchemaRecursive(best)
schemaMap["items"] = cleanedBest
} else {
cleanJSONSchemaRecursive(items)
}
} else {
// 遍历所有值递归
for _, v := range schemaMap {
if _, isMap := v.(map[string]any); isMap {
cleanJSONSchemaRecursive(v)
} else if arr, isArr := v.([]any); isArr {
for _, item := range arr {
cleanJSONSchemaRecursive(item)
}
}
}
}
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
var unionArray []any
typeStr, _ := schemaMap["type"].(string)
if typeStr == "" || typeStr == "object" {
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
unionArray = anyOf
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
unionArray = oneOf
}
}
if len(unionArray) > 0 {
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
if bestMap, ok := bestBranch.(map[string]any); ok {
// 合并分支内容
for k, v := range bestMap {
if k == "properties" {
targetProps, _ := schemaMap["properties"].(map[string]any)
if targetProps == nil {
targetProps = make(map[string]any)
schemaMap["properties"] = targetProps
}
if sourceProps, ok := v.(map[string]any); ok {
for pk, pv := range sourceProps {
if _, exists := targetProps[pk]; !exists {
targetProps[pk] = deepCopy(pv)
}
}
}
} else if k == "required" {
targetReq, _ := schemaMap["required"].([]any)
if sourceReq, ok := v.([]any); ok {
for _, rv := range sourceReq {
// 简单的去重添加
exists := false
for _, tr := range targetReq {
if tr == rv {
exists = true
break
}
}
if !exists {
targetReq = append(targetReq, rv)
}
}
schemaMap["required"] = targetReq
}
} else if _, exists := schemaMap[k]; !exists {
schemaMap[k] = deepCopy(v)
}
}
}
}
}
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
looksLikeSchema := hasKey(schemaMap, "type") ||
hasKey(schemaMap, "properties") ||
hasKey(schemaMap, "items") ||
hasKey(schemaMap, "enum") ||
hasKey(schemaMap, "anyOf") ||
hasKey(schemaMap, "oneOf") ||
hasKey(schemaMap, "allOf")
if looksLikeSchema {
// 4. [ROBUST] 约束迁移
migrateConstraints(schemaMap)
// 5. [CRITICAL] 白名单过滤
allowedFields := map[string]bool{
"type": true,
"description": true,
"properties": true,
"required": true,
"items": true,
"enum": true,
"title": true,
}
for k := range schemaMap {
if !allowedFields[k] {
delete(schemaMap, k)
}
}
// 6. [SAFETY] 处理空 Object
if t, _ := schemaMap["type"].(string); t == "object" {
hasProps := false
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
hasProps = true
}
if !hasProps {
schemaMap["properties"] = map[string]any{
"reason": map[string]any{
"type": "string",
"description": "Reason for calling this tool",
},
}
schemaMap["required"] = []any{"reason"}
}
}
// 7. [SAFETY] Required 字段对齐
if props, ok := schemaMap["properties"].(map[string]any); ok {
if req, ok := schemaMap["required"].([]any); ok {
var validReq []any
for _, r := range req {
if rStr, ok := r.(string); ok {
if _, exists := props[rStr]; exists {
validReq = append(validReq, r)
}
}
}
if len(validReq) > 0 {
schemaMap["required"] = validReq
} else {
delete(schemaMap, "required")
}
}
}
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
isEffectivelyNullable := false
if typeVal, exists := schemaMap["type"]; exists {
var selectedType string
switch v := typeVal.(type) {
case string:
lower := strings.ToLower(v)
if lower == "null" {
isEffectivelyNullable = true
selectedType = "string" // fallback
} else {
selectedType = lower
}
case []any:
// ["string", "null"]
for _, t := range v {
if ts, ok := t.(string); ok {
lower := strings.ToLower(ts)
if lower == "null" {
isEffectivelyNullable = true
} else if selectedType == "" {
selectedType = lower
}
}
}
if selectedType == "" {
selectedType = "string"
}
}
schemaMap["type"] = selectedType
} else {
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
// 如果没有 type,但有 properties,补一个
if hasKey(schemaMap, "properties") {
schemaMap["type"] = "object"
} else {
// 默认为 string ? or object? Gemini 通常需要明确 type
schemaMap["type"] = "object"
}
}
if isEffectivelyNullable {
desc, _ := schemaMap["description"].(string)
if !strings.Contains(desc, "nullable") {
if desc != "" {
desc += " "
}
desc += "(nullable)"
schemaMap["description"] = desc
}
}
// 9. Enum 值强制转字符串
if enumVals, ok := schemaMap["enum"].([]any); ok {
hasNonString := false
for i, val := range enumVals {
if _, isStr := val.(string); !isStr {
hasNonString = true
if val == nil {
enumVals[i] = "null"
} else {
enumVals[i] = fmt.Sprintf("%v", val)
}
}
}
// If we mandated string values, we must ensure type is string
if hasNonString {
schemaMap["type"] = "string"
}
}
}
return schemaMap
}
func hasKey(m map[string]any, k string) bool {
_, ok := m[k]
return ok
}
func migrateConstraints(m map[string]any) {
constraints := []struct {
key string
label string
}{
{"minLength", "minLen"},
{"maxLength", "maxLen"},
{"pattern", "pattern"},
{"minimum", "min"},
{"maximum", "max"},
{"multipleOf", "multipleOf"},
{"exclusiveMinimum", "exclMin"},
{"exclusiveMaximum", "exclMax"},
{"minItems", "minItems"},
{"maxItems", "maxItems"},
{"propertyNames", "propertyNames"},
{"format", "format"},
}
var hints []string
for _, c := range constraints {
if val, ok := m[c.key]; ok && val != nil {
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
}
}
if len(hints) > 0 {
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
desc, _ := m["description"].(string)
if !strings.Contains(desc, suffix) {
m["description"] = desc + suffix
}
}
}
// mergeAllOf 合并 allOf
func mergeAllOf(m map[string]any) {
allOf, ok := m["allOf"].([]any)
if !ok {
return
}
delete(m, "allOf")
mergedProps := make(map[string]any)
mergedReq := make(map[string]bool)
otherFields := make(map[string]any)
for _, sub := range allOf {
if subMap, ok := sub.(map[string]any); ok {
// Props
if props, ok := subMap["properties"].(map[string]any); ok {
for k, v := range props {
mergedProps[k] = v
}
}
// Required
if reqs, ok := subMap["required"].([]any); ok {
for _, r := range reqs {
if s, ok := r.(string); ok {
mergedReq[s] = true
}
}
}
// Others
for k, v := range subMap {
if k != "properties" && k != "required" && k != "allOf" {
if _, exists := otherFields[k]; !exists {
otherFields[k] = v
}
}
}
}
}
// Apply
for k, v := range otherFields {
if _, exists := m[k]; !exists {
m[k] = v
}
}
if len(mergedProps) > 0 {
existProps, _ := m["properties"].(map[string]any)
if existProps == nil {
existProps = make(map[string]any)
m["properties"] = existProps
}
for k, v := range mergedProps {
if _, exists := existProps[k]; !exists {
existProps[k] = v
}
}
}
if len(mergedReq) > 0 {
existReq, _ := m["required"].([]any)
var validReqs []any
for _, r := range existReq {
if s, ok := r.(string); ok {
validReqs = append(validReqs, s)
delete(mergedReq, s) // already exists
}
}
// append new
for r := range mergedReq {
validReqs = append(validReqs, r)
}
m["required"] = validReqs
}
}
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
func extractBestSchemaFromUnion(unionArray []any) any {
var bestOption any
bestScore := -1
for _, item := range unionArray {
score := scoreSchemaOption(item)
if score > bestScore {
bestScore = score
bestOption = item
}
}
return bestOption
}
func scoreSchemaOption(val any) int {
m, ok := val.(map[string]any)
if !ok {
return 0
}
typeStr, _ := m["type"].(string)
if hasKey(m, "properties") || typeStr == "object" {
return 3
}
if hasKey(m, "items") || typeStr == "array" {
return 2
}
if typeStr != "" && typeStr != "null" {
return 1
}
return 0
}
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
func DeepCleanUndefined(value any) {
if value == nil {
return
}
switch v := value.(type) {
case map[string]any:
for k, val := range v {
if s, ok := val.(string); ok && s == "[undefined]" {
delete(v, k)
continue
}
DeepCleanUndefined(val)
}
case []any:
for _, val := range v {
DeepCleanUndefined(val)
}
}
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"strings" "strings"
) )
...@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { ...@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 检查是否结束 // 检查是否结束
if len(geminiResp.Candidates) > 0 { if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason finishReason := geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
if finishReason != "" { if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason)) _, _ = result.Write(p.emitFinish(finishReason))
} }
......
...@@ -24,9 +24,9 @@ const ( ...@@ -24,9 +24,9 @@ const (
RedirectURI = "https://platform.claude.com/oauth/code/callback" RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes - Browser URL (includes org:create_api_key for user authorization) // Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code" ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Internal API call (org:create_api_key not supported in API) // Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code" ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Setup token (inference only) // Scopes - Setup token (inference only)
ScopeInference = "user:inference" ScopeInference = "user:inference"
...@@ -215,5 +215,6 @@ type OrgInfo struct { ...@@ -215,5 +215,6 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response // AccountInfo represents account info from OAuth response
type AccountInfo struct { type AccountInfo struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
} }
//go:build integration
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests to external services and should be run manually.
//
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// skipIfExternalServiceUnavailable checks if the external service is available.
// If not, it skips the test instead of failing.
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
t.Helper()
if err != nil {
// Check for common network/TLS errors that indicate external service issues
errStr := err.Error()
if strings.Contains(errStr, "certificate has expired") ||
strings.Contains(errStr, "certificate is not yet valid") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable") ||
strings.Contains(errStr, "timeout") {
t.Skipf("skipping test: external service unavailable: %v", err)
}
t.Fatalf("failed to get fingerprint: %v", err)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. // Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
// //
// Integration tests for verifying TLS fingerprint correctness. // Unit tests for TLS fingerprint dialer.
// These tests make actual network requests and should be run manually. // Integration tests that require external network are in dialer_integration_test.go
// and require the 'integration' build tag.
// //
// Run with: go test -v ./internal/pkg/tlsfingerprint/... // Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/... // Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint package tlsfingerprint
import ( import (
"context"
"encoding/json"
"io"
"net/http"
"net/url" "net/url"
"strings"
"testing" "testing"
"time"
) )
// FingerprintResponse represents the response from tls.peet.ws/api/all. // FingerprintResponse represents the response from tls.peet.ws/api/all.
...@@ -36,148 +31,6 @@ type TLSInfo struct { ...@@ -36,148 +31,6 @@ type TLSInfo struct {
SessionID string `json:"session_id"` SessionID string `json:"session_id"`
} }
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping network test in short mode")
}
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints. // TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) { func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles // Create two dialers with different profiles
...@@ -305,139 +158,3 @@ func mustParseURL(rawURL string) *url.URL { ...@@ -305,139 +158,3 @@ func mustParseURL(rawURL string) *url.URL {
} }
return u return u
} }
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
return nil
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}
package repository
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// AESEncryptor implements SecretEncryptor using AES-256-GCM
type AESEncryptor struct {
key []byte
}
// NewAESEncryptor creates a new AES encryptor
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
}
return &AESEncryptor{key: key}, nil
}
// Encrypt encrypts plaintext using AES-256-GCM
// Output format: base64(nonce + ciphertext + tag)
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
// Generate a random nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
// Encrypt the plaintext
// Seal appends the ciphertext and tag to the nonce
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode as base64
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts ciphertext using AES-256-GCM
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("decode base64: %w", err)
}
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
// Extract nonce and ciphertext
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}
...@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User { ...@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil return nil
} }
return &service.User{ return &service.User{
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Username: u.Username, Username: u.Username,
Notes: u.Notes, Notes: u.Notes,
PasswordHash: u.PasswordHash, PasswordHash: u.PasswordHash,
Role: u.Role, Role: u.Role,
Balance: u.Balance, Balance: u.Balance,
Concurrency: u.Concurrency, Concurrency: u.Concurrency,
Status: u.Status, Status: u.Status,
CreatedAt: u.CreatedAt, TotpSecretEncrypted: u.TotpSecretEncrypted,
UpdatedAt: u.UpdatedAt, TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
} }
} }
......
...@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey ...@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
client := s.clientFactory(proxyURL) client := s.clientFactory(proxyURL)
var orgs []struct { var orgs []struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Name string `json:"name"`
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
} }
targetURL := s.baseURL + "/api/organizations" targetURL := s.baseURL + "/api/organizations"
...@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey ...@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("no organizations found") return "", fmt.Errorf("no organizations found")
} }
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID) // 如果只有一个组织,直接使用
if len(orgs) == 1 {
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for _, org := range orgs {
if org.RavenType != nil && *org.RavenType == "team" {
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
org.UUID, org.Name, *org.RavenType)
return org.UUID, nil
}
}
// 如果没有 team 类型的组织,使用第一个
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil return orgs[0].UUID, nil
} }
......
...@@ -9,13 +9,27 @@ import ( ...@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
const verifyCodeKeyPrefix = "verify_code:" const (
verifyCodeKeyPrefix = "verify_code:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
)
// verifyCodeKey generates the Redis key for email verification code. // verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string { func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email return verifyCodeKeyPrefix + email
} }
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email
}
type emailCache struct { type emailCache struct {
rdb *redis.Client rdb *redis.Client
} }
...@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e ...@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key := verifyCodeKey(email) key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
// Password reset token methods
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
key := passwordResetKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.PasswordResetTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
key := passwordResetKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
key := passwordResetKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset email cooldown methods
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
key := passwordResetSentAtKey(email)
exists, err := c.rdb.Exists(ctx, key).Result()
return err == nil && exists > 0
}
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
totpSetupKeyPrefix = "totp:setup:"
totpLoginKeyPrefix = "totp:login:"
totpAttemptsKeyPrefix = "totp:attempts:"
totpAttemptsTTL = 15 * time.Minute
)
// TotpCache implements service.TotpCache using Redis
type TotpCache struct {
rdb *redis.Client
}
// NewTotpCache creates a new TOTP cache
func NewTotpCache(rdb *redis.Client) service.TotpCache {
return &TotpCache{rdb: rdb}
}
// GetSetupSession retrieves a TOTP setup session
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get setup session: %w", err)
}
var session service.TotpSetupSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal setup session: %w", err)
}
return &session, nil
}
// SetSetupSession stores a TOTP setup session
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal setup session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set setup session: %w", err)
}
return nil
}
// DeleteSetupSession deletes a TOTP setup session
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
// GetLoginSession retrieves a TOTP login session
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
key := totpLoginKeyPrefix + tempToken
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get login session: %w", err)
}
var session service.TotpLoginSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal login session: %w", err)
}
return &session, nil
}
// SetLoginSession stores a TOTP login session
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
key := totpLoginKeyPrefix + tempToken
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal login session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set login session: %w", err)
}
return nil
}
// DeleteLoginSession deletes a TOTP login session
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
key := totpLoginKeyPrefix + tempToken
return c.rdb.Del(ctx, key).Err()
}
// IncrementVerifyAttempts increments the verify attempt counter
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
// Use pipeline for atomic increment and set TTL
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, totpAttemptsTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("increment verify attempts: %w", err)
}
count, err := incrCmd.Result()
if err != nil {
return 0, fmt.Errorf("get increment result: %w", err)
}
return int(count), nil
}
// GetVerifyAttempts gets the current verify attempt count
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
count, err := c.rdb.Get(ctx, key).Int()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, fmt.Errorf("get verify attempts: %w", err)
}
return count, nil
}
// ClearVerifyAttempts clears the verify attempt counter
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment