Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a161fcc8
Commit
a161fcc8
authored
Jan 26, 2026
by
cyhhao
Browse files
Merge branch 'main' of github.com:Wei-Shaw/sub2api
parents
65e69738
e32c5f53
Changes
119
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/gateway_handler.go
View file @
a161fcc8
...
@@ -209,18 +209,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -209,18 +209,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account
:=
selection
.
Account
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
setOpsSelectedAccount
(
c
,
account
.
ID
)
// 检查预热请求拦截(在账号选择后、转发前检查)
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
account
.
IsInterceptWarmupEnabled
()
{
interceptType
:=
detectInterceptType
(
body
)
if
interceptType
!=
InterceptTypeNone
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
selection
.
ReleaseFunc
()
}
}
if
reqStream
{
if
reqStream
{
sendMock
Warmup
Stream
(
c
,
reqModel
)
sendMock
Intercept
Stream
(
c
,
reqModel
,
interceptType
)
}
else
{
}
else
{
sendMock
Warmup
Response
(
c
,
reqModel
)
sendMock
Intercept
Response
(
c
,
reqModel
,
interceptType
)
}
}
return
return
}
}
}
// 3. 获取账号并发槽位
// 3. 获取账号并发槽位
accountReleaseFunc
:=
selection
.
ReleaseFunc
accountReleaseFunc
:=
selection
.
ReleaseFunc
...
@@ -344,18 +347,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -344,18 +347,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account
:=
selection
.
Account
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
setOpsSelectedAccount
(
c
,
account
.
ID
)
// 检查预热请求拦截(在账号选择后、转发前检查)
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
account
.
IsInterceptWarmupEnabled
()
{
interceptType
:=
detectInterceptType
(
body
)
if
interceptType
!=
InterceptTypeNone
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
selection
.
ReleaseFunc
()
}
}
if
reqStream
{
if
reqStream
{
sendMock
Warmup
Stream
(
c
,
reqModel
)
sendMock
Intercept
Stream
(
c
,
reqModel
,
interceptType
)
}
else
{
}
else
{
sendMock
Warmup
Response
(
c
,
reqModel
)
sendMock
Intercept
Response
(
c
,
reqModel
,
interceptType
)
}
}
return
return
}
}
}
// 3. 获取账号并发槽位
// 3. 获取账号并发槽位
accountReleaseFunc
:=
selection
.
ReleaseFunc
accountReleaseFunc
:=
selection
.
ReleaseFunc
...
@@ -768,17 +774,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
...
@@ -768,17 +774,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
}
}
}
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
// InterceptType 表示请求拦截类型
func
isWarmupRequest
(
body
[]
byte
)
bool
{
type
InterceptType
int
// 快速检查:如果body不包含关键字,直接返回false
const
(
InterceptTypeNone
InterceptType
=
iota
InterceptTypeWarmup
// 预热请求(返回 "New Conversation")
InterceptTypeSuggestionMode
// SUGGESTION MODE(返回空字符串)
)
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func
detectInterceptType
(
body
[]
byte
)
InterceptType
{
// 快速检查:如果不包含任何关键字,直接返回
bodyStr
:=
string
(
body
)
bodyStr
:=
string
(
body
)
if
!
strings
.
Contains
(
bodyStr
,
"title"
)
&&
!
strings
.
Contains
(
bodyStr
,
"Warmup"
)
{
hasSuggestionMode
:=
strings
.
Contains
(
bodyStr
,
"[SUGGESTION MODE:"
)
return
false
hasWarmupKeyword
:=
strings
.
Contains
(
bodyStr
,
"title"
)
||
strings
.
Contains
(
bodyStr
,
"Warmup"
)
if
!
hasSuggestionMode
&&
!
hasWarmupKeyword
{
return
InterceptTypeNone
}
}
// 解析
完整
请求
// 解析请求
(只解析一次)
var
req
struct
{
var
req
struct
{
Messages
[]
struct
{
Messages
[]
struct
{
Role
string
`json:"role"`
Content
[]
struct
{
Content
[]
struct
{
Type
string
`json:"type"`
Type
string
`json:"type"`
Text
string
`json:"text"`
Text
string
`json:"text"`
...
@@ -789,43 +808,71 @@ func isWarmupRequest(body []byte) bool {
...
@@ -789,43 +808,71 @@ func isWarmupRequest(body []byte) bool {
}
`json:"system"`
}
`json:"system"`
}
}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
false
return
InterceptTypeNone
}
// 检查 SUGGESTION MODE(最后一条 user 消息)
if
hasSuggestionMode
&&
len
(
req
.
Messages
)
>
0
{
lastMsg
:=
req
.
Messages
[
len
(
req
.
Messages
)
-
1
]
if
lastMsg
.
Role
==
"user"
&&
len
(
lastMsg
.
Content
)
>
0
&&
lastMsg
.
Content
[
0
]
.
Type
==
"text"
&&
strings
.
HasPrefix
(
lastMsg
.
Content
[
0
]
.
Text
,
"[SUGGESTION MODE:"
)
{
return
InterceptTypeSuggestionMode
}
}
}
// 检查 Warmup 请求
if
hasWarmupKeyword
{
// 检查 messages 中的标题提示模式
// 检查 messages 中的标题提示模式
for
_
,
msg
:=
range
req
.
Messages
{
for
_
,
msg
:=
range
req
.
Messages
{
for
_
,
content
:=
range
msg
.
Content
{
for
_
,
content
:=
range
msg
.
Content
{
if
content
.
Type
==
"text"
{
if
content
.
Type
==
"text"
{
if
strings
.
Contains
(
content
.
Text
,
"Please write a 5-10 word title for the following conversation:"
)
||
if
strings
.
Contains
(
content
.
Text
,
"Please write a 5-10 word title for the following conversation:"
)
||
content
.
Text
==
"Warmup"
{
content
.
Text
==
"Warmup"
{
return
true
return
InterceptTypeWarmup
}
}
}
}
}
}
}
}
// 检查 system 中的标题提取模式
// 检查 system 中的标题提取模式
for
_
,
system
:=
range
req
.
System
{
for
_
,
sys
:=
range
req
.
System
{
if
strings
.
Contains
(
system
.
Text
,
"nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title"
)
{
if
strings
.
Contains
(
sys
.
Text
,
"nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title"
)
{
return
true
return
InterceptTypeWarmup
}
}
}
}
}
return
fals
e
return
InterceptTypeNon
e
}
}
// sendMock
Warmup
Stream 发送流式 mock 响应(用于
预热
请求拦截)
// sendMock
Intercept
Stream 发送流式 mock 响应(用于请求拦截)
func
sendMock
Warmup
Stream
(
c
*
gin
.
Context
,
model
string
)
{
func
sendMock
Intercept
Stream
(
c
*
gin
.
Context
,
model
string
,
interceptType
InterceptType
)
{
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
// 根据拦截类型决定响应内容
var
msgID
string
var
outputTokens
int
var
textDeltas
[]
string
switch
interceptType
{
case
InterceptTypeSuggestionMode
:
msgID
=
"msg_mock_suggestion"
outputTokens
=
1
textDeltas
=
[]
string
{
""
}
// 空内容
default
:
// InterceptTypeWarmup
msgID
=
"msg_mock_warmup"
outputTokens
=
2
textDeltas
=
[]
string
{
"New"
,
" Conversation"
}
}
// Build message_start event with proper JSON marshaling
// Build message_start event with proper JSON marshaling
messageStart
:=
map
[
string
]
any
{
messageStart
:=
map
[
string
]
any
{
"type"
:
"message_start"
,
"type"
:
"message_start"
,
"message"
:
map
[
string
]
any
{
"message"
:
map
[
string
]
any
{
"id"
:
"
msg
_mock_warmup"
,
"id"
:
msg
ID
,
"type"
:
"message"
,
"type"
:
"message"
,
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"model"
:
model
,
"model"
:
model
,
...
@@ -840,16 +887,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
...
@@ -840,16 +887,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
}
messageStartJSON
,
_
:=
json
.
Marshal
(
messageStart
)
messageStartJSON
,
_
:=
json
.
Marshal
(
messageStart
)
// Build events
events
:=
[]
string
{
events
:=
[]
string
{
`event: message_start`
+
"
\n
"
+
`data: `
+
string
(
messageStartJSON
),
`event: message_start`
+
"
\n
"
+
`data: `
+
string
(
messageStartJSON
),
`event: content_block_start`
+
"
\n
"
+
`data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`
,
`event: content_block_start`
+
"
\n
"
+
`data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`
,
`event: content_block_delta`
+
"
\n
"
+
`data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`
,
`event: content_block_delta`
+
"
\n
"
+
`data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`
,
`event: content_block_stop`
+
"
\n
"
+
`data: {"index":0,"type":"content_block_stop"}`
,
`event: message_delta`
+
"
\n
"
+
`data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`
,
`event: message_stop`
+
"
\n
"
+
`data: {"type":"message_stop"}`
,
}
}
// Add text deltas
for
_
,
text
:=
range
textDeltas
{
delta
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
0
,
"delta"
:
map
[
string
]
string
{
"type"
:
"text_delta"
,
"text"
:
text
,
},
}
deltaJSON
,
_
:=
json
.
Marshal
(
delta
)
events
=
append
(
events
,
`event: content_block_delta`
+
"
\n
"
+
`data: `
+
string
(
deltaJSON
))
}
// Add final events
messageDelta
:=
map
[
string
]
any
{
"type"
:
"message_delta"
,
"delta"
:
map
[
string
]
any
{
"stop_reason"
:
"end_turn"
,
"stop_sequence"
:
nil
,
},
"usage"
:
map
[
string
]
int
{
"input_tokens"
:
10
,
"output_tokens"
:
outputTokens
,
},
}
messageDeltaJSON
,
_
:=
json
.
Marshal
(
messageDelta
)
events
=
append
(
events
,
`event: content_block_stop`
+
"
\n
"
+
`data: {"index":0,"type":"content_block_stop"}`
,
`event: message_delta`
+
"
\n
"
+
`data: `
+
string
(
messageDeltaJSON
),
`event: message_stop`
+
"
\n
"
+
`data: {"type":"message_stop"}`
,
)
for
_
,
event
:=
range
events
{
for
_
,
event
:=
range
events
{
_
,
_
=
c
.
Writer
.
WriteString
(
event
+
"
\n\n
"
)
_
,
_
=
c
.
Writer
.
WriteString
(
event
+
"
\n\n
"
)
c
.
Writer
.
Flush
()
c
.
Writer
.
Flush
()
...
@@ -857,18 +934,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
...
@@ -857,18 +934,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
}
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func
sendMockWarmupResponse
(
c
*
gin
.
Context
,
model
string
)
{
func
sendMockInterceptResponse
(
c
*
gin
.
Context
,
model
string
,
interceptType
InterceptType
)
{
var
msgID
,
text
string
var
outputTokens
int
switch
interceptType
{
case
InterceptTypeSuggestionMode
:
msgID
=
"msg_mock_suggestion"
text
=
""
outputTokens
=
1
default
:
// InterceptTypeWarmup
msgID
=
"msg_mock_warmup"
text
=
"New Conversation"
outputTokens
=
2
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"id"
:
"
msg
_mock_warmup"
,
"id"
:
msg
ID
,
"type"
:
"message"
,
"type"
:
"message"
,
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"model"
:
model
,
"model"
:
model
,
"content"
:
[]
gin
.
H
{{
"type"
:
"text"
,
"text"
:
"New Conversation"
}},
"content"
:
[]
gin
.
H
{{
"type"
:
"text"
,
"text"
:
text
}},
"stop_reason"
:
"end_turn"
,
"stop_reason"
:
"end_turn"
,
"usage"
:
gin
.
H
{
"usage"
:
gin
.
H
{
"input_tokens"
:
10
,
"input_tokens"
:
10
,
"output_tokens"
:
2
,
"output_tokens"
:
outputTokens
,
},
},
})
})
}
}
...
...
backend/internal/handler/gemini_cli_session_test.go
0 → 100644
View file @
a161fcc8
//go:build unit
package
handler
import
(
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestExtractGeminiCLISessionHash
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
body
string
privilegedUserID
string
wantEmpty
bool
wantHash
string
}{
{
name
:
"with privileged-user-id and tmp dir"
,
body
:
`{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`
,
privilegedUserID
:
"90785f52-8bbe-4b17-b111-a1ddea1636c3"
,
wantEmpty
:
false
,
wantHash
:
func
()
string
{
combined
:=
"90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash
:=
sha256
.
Sum256
([]
byte
(
combined
))
return
hex
.
EncodeToString
(
hash
[
:
])
}(),
},
{
name
:
"without privileged-user-id but with tmp dir"
,
body
:
`{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`
,
privilegedUserID
:
""
,
wantEmpty
:
false
,
wantHash
:
"f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
,
},
{
name
:
"without tmp dir"
,
body
:
`{"contents":[{"parts":[{"text":"Hello world"}]}]}`
,
privilegedUserID
:
"90785f52-8bbe-4b17-b111-a1ddea1636c3"
,
wantEmpty
:
true
,
},
{
name
:
"empty body"
,
body
:
""
,
privilegedUserID
:
"90785f52-8bbe-4b17-b111-a1ddea1636c3"
,
wantEmpty
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 创建测试上下文
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
"POST"
,
"/test"
,
nil
)
if
tt
.
privilegedUserID
!=
""
{
c
.
Request
.
Header
.
Set
(
"x-gemini-api-privileged-user-id"
,
tt
.
privilegedUserID
)
}
// 调用函数
result
:=
extractGeminiCLISessionHash
(
c
,
[]
byte
(
tt
.
body
))
// 验证结果
if
tt
.
wantEmpty
{
require
.
Empty
(
t
,
result
,
"expected empty session hash"
)
}
else
{
require
.
NotEmpty
(
t
,
result
,
"expected non-empty session hash"
)
require
.
Equal
(
t
,
tt
.
wantHash
,
result
,
"session hash mismatch"
)
}
})
}
}
func
TestGeminiCLITmpDirRegex
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
string
wantMatch
bool
wantHash
string
}{
{
name
:
"valid tmp dir path"
,
input
:
"/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
,
wantMatch
:
true
,
wantHash
:
"f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
,
},
{
name
:
"valid tmp dir path in text"
,
input
:
"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740
\n
Other text"
,
wantMatch
:
true
,
wantHash
:
"f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
,
},
{
name
:
"invalid hash length"
,
input
:
"/Users/ianshaw/.gemini/tmp/abc123"
,
wantMatch
:
false
,
},
{
name
:
"no tmp dir"
,
input
:
"Hello world"
,
wantMatch
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
match
:=
geminiCLITmpDirRegex
.
FindStringSubmatch
(
tt
.
input
)
if
tt
.
wantMatch
{
require
.
NotNil
(
t
,
match
,
"expected regex to match"
)
require
.
Len
(
t
,
match
,
2
,
"expected 2 capture groups"
)
require
.
Equal
(
t
,
tt
.
wantHash
,
match
[
1
],
"hash mismatch"
)
}
else
{
require
.
Nil
(
t
,
match
,
"expected regex not to match"
)
}
})
}
}
backend/internal/handler/gemini_v1beta_handler.go
View file @
a161fcc8
package
handler
package
handler
import
(
import
(
"bytes"
"context"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"errors"
"io"
"io"
"log"
"log"
"net/http"
"net/http"
"regexp"
"strings"
"strings"
"time"
"time"
...
@@ -19,6 +23,17 @@ import (
...
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
)
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var
geminiCLITmpDirRegex
=
regexp
.
MustCompile
(
`/\.gemini/tmp/([A-Fa-f0-9]{64})`
)
func
isGeminiCLIRequest
(
c
*
gin
.
Context
,
body
[]
byte
)
bool
{
if
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-gemini-api-privileged-user-id"
))
!=
""
{
return
true
}
return
geminiCLITmpDirRegex
.
Match
(
body
)
}
// GeminiV1BetaListModels proxies:
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
// GET /v1beta/models
func
(
h
*
GatewayHandler
)
GeminiV1BetaListModels
(
c
*
gin
.
Context
)
{
func
(
h
*
GatewayHandler
)
GeminiV1BetaListModels
(
c
*
gin
.
Context
)
{
...
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
}
// 3) select account (sticky session based on request body)
// 3) select account (sticky session based on request body)
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
sessionHash
:=
extractGeminiCLISessionHash
(
c
,
body
)
if
sessionHash
==
""
{
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionHash
=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
}
sessionKey
:=
sessionHash
sessionKey
:=
sessionHash
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
sessionKey
=
"gemini:"
+
sessionHash
}
}
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
var
sessionBoundAccountID
int64
if
sessionKey
!=
""
{
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
}
isCLI
:=
isGeminiCLIRequest
(
c
,
body
)
cleanedForUnknownBinding
:=
false
maxAccountSwitches
:=
h
.
maxAccountSwitchesGemini
maxAccountSwitches
:=
h
.
maxAccountSwitchesGemini
switchCount
:=
0
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
...
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account
:=
selection
.
Account
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
setOpsSelectedAccount
(
c
,
account
.
ID
)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if
sessionBoundAccountID
>
0
&&
sessionBoundAccountID
!=
account
.
ID
{
log
.
Printf
(
"[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature"
,
sessionBoundAccountID
,
account
.
ID
)
body
=
service
.
CleanGeminiNativeThoughtSignatures
(
body
)
sessionBoundAccountID
=
account
.
ID
}
else
if
sessionKey
!=
""
&&
sessionBoundAccountID
==
0
&&
isCLI
&&
!
cleanedForUnknownBinding
&&
bytes
.
Contains
(
body
,
[]
byte
(
`"thoughtSignature"`
))
{
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log
.
Printf
(
"[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively"
)
body
=
service
.
CleanGeminiNativeThoughtSignatures
(
body
)
cleanedForUnknownBinding
=
true
sessionBoundAccountID
=
account
.
ID
}
else
if
sessionBoundAccountID
==
0
{
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID
=
account
.
ID
}
// 4) account concurrency slot
// 4) account concurrency slot
accountReleaseFunc
:=
selection
.
ReleaseFunc
accountReleaseFunc
:=
selection
.
ReleaseFunc
if
!
selection
.
Acquired
{
if
!
selection
.
Acquired
{
...
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
...
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
}
return
false
return
false
}
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
// 2. 从 header 中提取 privileged-user-id(UUID)
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func
extractGeminiCLISessionHash
(
c
*
gin
.
Context
,
body
[]
byte
)
string
{
// 1. 从请求体中提取 tmp 目录哈希
match
:=
geminiCLITmpDirRegex
.
FindSubmatch
(
body
)
if
len
(
match
)
<
2
{
return
""
// 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash
:=
string
(
match
[
1
])
// 2. 提取 privileged-user-id
privilegedUserID
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-gemini-api-privileged-user-id"
))
// 3. 组合生成最终的 session hash
if
privilegedUserID
!=
""
{
// 组合两个标识符:privileged-user-id + tmp 目录哈希
combined
:=
privilegedUserID
+
":"
+
tmpDirHash
hash
:=
sha256
.
Sum256
([]
byte
(
combined
))
return
hex
.
EncodeToString
(
hash
[
:
])
}
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
return
tmpDirHash
}
backend/internal/handler/handler.go
View file @
a161fcc8
...
@@ -37,6 +37,7 @@ type Handlers struct {
...
@@ -37,6 +37,7 @@ type Handlers struct {
Gateway
*
GatewayHandler
Gateway
*
GatewayHandler
OpenAIGateway
*
OpenAIGatewayHandler
OpenAIGateway
*
OpenAIGatewayHandler
Setting
*
SettingHandler
Setting
*
SettingHandler
Totp
*
TotpHandler
}
}
// BuildInfo contains build-time information
// BuildInfo contains build-time information
...
...
backend/internal/handler/setting_handler.go
View file @
a161fcc8
...
@@ -35,6 +35,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
...
@@ -35,6 +35,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
PromoCodeEnabled
:
settings
.
PromoCodeEnabled
,
PromoCodeEnabled
:
settings
.
PromoCodeEnabled
,
PasswordResetEnabled
:
settings
.
PasswordResetEnabled
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
SiteName
:
settings
.
SiteName
,
SiteName
:
settings
.
SiteName
,
...
...
backend/internal/handler/totp_handler.go
0 → 100644
View file @
a161fcc8
package
handler
import
(
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type
TotpHandler
struct
{
totpService
*
service
.
TotpService
}
// NewTotpHandler creates a new TotpHandler
func
NewTotpHandler
(
totpService
*
service
.
TotpService
)
*
TotpHandler
{
return
&
TotpHandler
{
totpService
:
totpService
,
}
}
// TotpStatusResponse represents the TOTP status response
type
TotpStatusResponse
struct
{
Enabled
bool
`json:"enabled"`
EnabledAt
*
int64
`json:"enabled_at,omitempty"`
// Unix timestamp
FeatureEnabled
bool
`json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func
(
h
*
TotpHandler
)
GetStatus
(
c
*
gin
.
Context
)
{
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
status
,
err
:=
h
.
totpService
.
GetStatus
(
c
.
Request
.
Context
(),
subject
.
UserID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
resp
:=
TotpStatusResponse
{
Enabled
:
status
.
Enabled
,
FeatureEnabled
:
status
.
FeatureEnabled
,
}
if
status
.
EnabledAt
!=
nil
{
ts
:=
status
.
EnabledAt
.
Unix
()
resp
.
EnabledAt
=
&
ts
}
response
.
Success
(
c
,
resp
)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type
TotpSetupRequest
struct
{
EmailCode
string
`json:"email_code"`
Password
string
`json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type
TotpSetupResponse
struct
{
Secret
string
`json:"secret"`
QRCodeURL
string
`json:"qr_code_url"`
SetupToken
string
`json:"setup_token"`
Countdown
int
`json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func
(
h
*
TotpHandler
)
InitiateSetup
(
c
*
gin
.
Context
)
{
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
var
req
TotpSetupRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
// Allow empty body (optional params)
req
=
TotpSetupRequest
{}
}
result
,
err
:=
h
.
totpService
.
InitiateSetup
(
c
.
Request
.
Context
(),
subject
.
UserID
,
req
.
EmailCode
,
req
.
Password
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
TotpSetupResponse
{
Secret
:
result
.
Secret
,
QRCodeURL
:
result
.
QRCodeURL
,
SetupToken
:
result
.
SetupToken
,
Countdown
:
result
.
Countdown
,
})
}
// TotpEnableRequest represents the request to enable TOTP
type
TotpEnableRequest
struct
{
TotpCode
string
`json:"totp_code" binding:"required,len=6"`
SetupToken
string
`json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func
(
h
*
TotpHandler
)
Enable
(
c
*
gin
.
Context
)
{
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
var
req
TotpEnableRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
err
:=
h
.
totpService
.
CompleteSetup
(
c
.
Request
.
Context
(),
subject
.
UserID
,
req
.
TotpCode
,
req
.
SetupToken
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"success"
:
true
})
}
// TotpDisableRequest represents the request to disable TOTP
type
TotpDisableRequest
struct
{
EmailCode
string
`json:"email_code"`
Password
string
`json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func
(
h
*
TotpHandler
)
Disable
(
c
*
gin
.
Context
)
{
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
var
req
TotpDisableRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
err
:=
h
.
totpService
.
Disable
(
c
.
Request
.
Context
(),
subject
.
UserID
,
req
.
EmailCode
,
req
.
Password
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"success"
:
true
})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func
(
h
*
TotpHandler
)
GetVerificationMethod
(
c
*
gin
.
Context
)
{
method
:=
h
.
totpService
.
GetVerificationMethod
(
c
.
Request
.
Context
())
response
.
Success
(
c
,
method
)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func
(
h
*
TotpHandler
)
SendVerifyCode
(
c
*
gin
.
Context
)
{
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
if
err
:=
h
.
totpService
.
SendVerifyCode
(
c
.
Request
.
Context
(),
subject
.
UserID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"success"
:
true
})
}
backend/internal/handler/wire.go
View file @
a161fcc8
...
@@ -70,6 +70,7 @@ func ProvideHandlers(
...
@@ -70,6 +70,7 @@ func ProvideHandlers(
gatewayHandler
*
GatewayHandler
,
gatewayHandler
*
GatewayHandler
,
openaiGatewayHandler
*
OpenAIGatewayHandler
,
openaiGatewayHandler
*
OpenAIGatewayHandler
,
settingHandler
*
SettingHandler
,
settingHandler
*
SettingHandler
,
totpHandler
*
TotpHandler
,
)
*
Handlers
{
)
*
Handlers
{
return
&
Handlers
{
return
&
Handlers
{
Auth
:
authHandler
,
Auth
:
authHandler
,
...
@@ -82,6 +83,7 @@ func ProvideHandlers(
...
@@ -82,6 +83,7 @@ func ProvideHandlers(
Gateway
:
gatewayHandler
,
Gateway
:
gatewayHandler
,
OpenAIGateway
:
openaiGatewayHandler
,
OpenAIGateway
:
openaiGatewayHandler
,
Setting
:
settingHandler
,
Setting
:
settingHandler
,
Totp
:
totpHandler
,
}
}
}
}
...
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
...
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler
,
NewSubscriptionHandler
,
NewGatewayHandler
,
NewGatewayHandler
,
NewOpenAIGatewayHandler
,
NewOpenAIGatewayHandler
,
NewTotpHandler
,
ProvideSettingHandler
,
ProvideSettingHandler
,
// Admin handlers
// Admin handlers
...
...
backend/internal/pkg/antigravity/request_transformer.go
View file @
a161fcc8
...
@@ -7,13 +7,11 @@ import (
...
@@ -7,13 +7,11 @@ import (
"fmt"
"fmt"
"log"
"log"
"math/rand"
"math/rand"
"os"
"strconv"
"strconv"
"strings"
"strings"
"sync"
"sync"
"time"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/google/uuid"
)
)
...
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
...
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text
:
block
.
Thinking
,
Text
:
block
.
Thinking
,
Thought
:
true
,
Thought
:
true
,
}
}
// 保留原有 signature(Claude 模型需要有效的 signature)
// signature 处理:
if
block
.
Signature
!=
""
{
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
dummyThoughtSignature
)
{
part
.
ThoughtSignature
=
block
.
Signature
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
!
allowDummyThought
{
}
else
if
!
allowDummyThought
{
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
...
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
...
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
},
}
}
// tool_use 的 signature 处理:
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if
allowDummyThought
{
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
dummyThoughtSignature
)
{
part
.
ThoughtSignature
=
dummyThoughtSignature
}
else
if
block
.
Signature
!=
""
&&
block
.
Signature
!=
dummyThoughtSignature
{
part
.
ThoughtSignature
=
block
.
Signature
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
allowDummyThought
{
part
.
ThoughtSignature
=
dummyThoughtSignature
}
}
parts
=
append
(
parts
,
part
)
parts
=
append
(
parts
,
part
)
...
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
...
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
// 清理 JSON Schema
// 清理 JSON Schema
params
:=
cleanJSONSchema
(
inputSchema
)
// 1. 深度清理 [undefined] 值
DeepCleanUndefined
(
inputSchema
)
// 2. 转换为符合 Gemini v1internal 的 schema
params
:=
CleanJSONSchema
(
inputSchema
)
// 为 nil schema 提供默认值
// 为 nil schema 提供默认值
if
params
==
nil
{
if
params
==
nil
{
params
=
map
[
string
]
any
{
params
=
map
[
string
]
any
{
"type"
:
"
OBJECT"
,
"type"
:
"
object"
,
// lowercase type
"properties"
:
map
[
string
]
any
{},
"properties"
:
map
[
string
]
any
{},
}
}
}
}
...
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
...
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
FunctionDeclarations
:
funcDecls
,
FunctionDeclarations
:
funcDecls
,
}}
}}
}
}
// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
func
cleanJSONSchema
(
schema
map
[
string
]
any
)
map
[
string
]
any
{
if
schema
==
nil
{
return
nil
}
cleaned
:=
cleanSchemaValue
(
schema
,
"$"
)
result
,
ok
:=
cleaned
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
// 确保有 type 字段(默认 OBJECT)
if
_
,
hasType
:=
result
[
"type"
];
!
hasType
{
result
[
"type"
]
=
"OBJECT"
}
// 确保有 properties 字段(默认空对象)
if
_
,
hasProps
:=
result
[
"properties"
];
!
hasProps
{
result
[
"properties"
]
=
make
(
map
[
string
]
any
)
}
// 验证 required 中的字段都存在于 properties 中
if
required
,
ok
:=
result
[
"required"
]
.
([]
any
);
ok
{
if
props
,
ok
:=
result
[
"properties"
]
.
(
map
[
string
]
any
);
ok
{
validRequired
:=
make
([]
any
,
0
,
len
(
required
))
for
_
,
r
:=
range
required
{
if
reqName
,
ok
:=
r
.
(
string
);
ok
{
if
_
,
exists
:=
props
[
reqName
];
exists
{
validRequired
=
append
(
validRequired
,
r
)
}
}
}
if
len
(
validRequired
)
>
0
{
result
[
"required"
]
=
validRequired
}
else
{
delete
(
result
,
"required"
)
}
}
}
return
result
}
var
schemaValidationKeys
=
map
[
string
]
bool
{
"minLength"
:
true
,
"maxLength"
:
true
,
"pattern"
:
true
,
"minimum"
:
true
,
"maximum"
:
true
,
"exclusiveMinimum"
:
true
,
"exclusiveMaximum"
:
true
,
"multipleOf"
:
true
,
"uniqueItems"
:
true
,
"minItems"
:
true
,
"maxItems"
:
true
,
"minProperties"
:
true
,
"maxProperties"
:
true
,
"patternProperties"
:
true
,
"propertyNames"
:
true
,
"dependencies"
:
true
,
"dependentSchemas"
:
true
,
"dependentRequired"
:
true
,
}
var
warnedSchemaKeys
sync
.
Map
func
schemaCleaningWarningsEnabled
()
bool
{
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
if
v
:=
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_SCHEMA_CLEAN_WARN"
));
v
!=
""
{
switch
strings
.
ToLower
(
v
)
{
case
"1"
,
"true"
,
"yes"
,
"on"
:
return
true
case
"0"
,
"false"
,
"no"
,
"off"
:
return
false
}
}
// 默认:非 release 模式下输出(debug/test)
return
gin
.
Mode
()
!=
gin
.
ReleaseMode
}
func
warnSchemaKeyRemovedOnce
(
key
,
path
string
)
{
if
!
schemaCleaningWarningsEnabled
()
{
return
}
if
!
schemaValidationKeys
[
key
]
{
return
}
if
_
,
loaded
:=
warnedSchemaKeys
.
LoadOrStore
(
key
,
struct
{}{});
loaded
{
return
}
log
.
Printf
(
"[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q"
,
key
,
path
)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var
excludedSchemaKeys
=
map
[
string
]
bool
{
// 元 schema 字段
"$schema"
:
true
,
"$id"
:
true
,
"$ref"
:
true
,
// 字符串验证(Gemini 不支持)
"minLength"
:
true
,
"maxLength"
:
true
,
"pattern"
:
true
,
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
"minimum"
:
true
,
"maximum"
:
true
,
"exclusiveMinimum"
:
true
,
"exclusiveMaximum"
:
true
,
"multipleOf"
:
true
,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems"
:
true
,
"minItems"
:
true
,
"maxItems"
:
true
,
// 组合 schema(Gemini 不支持)
"oneOf"
:
true
,
"anyOf"
:
true
,
"allOf"
:
true
,
"not"
:
true
,
"if"
:
true
,
"then"
:
true
,
"else"
:
true
,
"$defs"
:
true
,
"definitions"
:
true
,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties"
:
true
,
"maxProperties"
:
true
,
"patternProperties"
:
true
,
"propertyNames"
:
true
,
"dependencies"
:
true
,
"dependentSchemas"
:
true
,
"dependentRequired"
:
true
,
// 其他不支持的字段
"default"
:
true
,
"const"
:
true
,
"examples"
:
true
,
"deprecated"
:
true
,
"readOnly"
:
true
,
"writeOnly"
:
true
,
"contentMediaType"
:
true
,
"contentEncoding"
:
true
,
// Claude 特有字段
"strict"
:
true
,
}
// cleanSchemaValue 递归清理 schema 值
func
cleanSchemaValue
(
value
any
,
path
string
)
any
{
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
result
:=
make
(
map
[
string
]
any
)
for
k
,
val
:=
range
v
{
// 跳过不支持的字段
if
excludedSchemaKeys
[
k
]
{
warnSchemaKeyRemovedOnce
(
k
,
path
)
continue
}
// 特殊处理 type 字段
if
k
==
"type"
{
result
[
k
]
=
cleanTypeValue
(
val
)
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if
k
==
"format"
{
if
formatStr
,
ok
:=
val
.
(
string
);
ok
{
// Gemini 只支持 date-time, date, time
if
formatStr
==
"date-time"
||
formatStr
==
"date"
||
formatStr
==
"time"
{
result
[
k
]
=
val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if
k
==
"additionalProperties"
{
if
boolVal
,
ok
:=
val
.
(
bool
);
ok
{
result
[
k
]
=
boolVal
}
else
{
// 如果是 schema 对象,转换为 false(更安全的默认值)
result
[
k
]
=
false
}
continue
}
// 递归清理所有值
result
[
k
]
=
cleanSchemaValue
(
val
,
path
+
"."
+
k
)
}
return
result
case
[]
any
:
// 递归处理数组中的每个元素
cleaned
:=
make
([]
any
,
0
,
len
(
v
))
for
i
,
item
:=
range
v
{
cleaned
=
append
(
cleaned
,
cleanSchemaValue
(
item
,
fmt
.
Sprintf
(
"%s[%d]"
,
path
,
i
)))
}
return
cleaned
default
:
return
value
}
}
// cleanTypeValue 处理 type 字段,转换为大写
func
cleanTypeValue
(
value
any
)
any
{
switch
v
:=
value
.
(
type
)
{
case
string
:
return
strings
.
ToUpper
(
v
)
case
[]
any
:
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
for
_
,
t
:=
range
v
{
if
ts
,
ok
:=
t
.
(
string
);
ok
&&
ts
!=
"null"
{
return
strings
.
ToUpper
(
ts
)
}
}
// 如果只有 null,返回 STRING
return
"STRING"
default
:
return
value
}
}
backend/internal/pkg/antigravity/request_transformer_test.go
View file @
a161fcc8
...
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
...
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
]`
t
.
Run
(
"Gemini
uses dummy
tool_use signature"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"Gemini
preserves provided
tool_use signature"
,
func
(
t
*
testing
.
T
)
{
toolIDToName
:=
make
(
map
[
string
]
string
)
toolIDToName
:=
make
(
map
[
string
]
string
)
parts
,
_
,
err
:=
buildParts
(
json
.
RawMessage
(
content
),
toolIDToName
,
true
)
parts
,
_
,
err
:=
buildParts
(
json
.
RawMessage
(
content
),
toolIDToName
,
true
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
...
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if
len
(
parts
)
!=
1
||
parts
[
0
]
.
FunctionCall
==
nil
{
if
len
(
parts
)
!=
1
||
parts
[
0
]
.
FunctionCall
==
nil
{
t
.
Fatalf
(
"expected 1 functionCall part, got %+v"
,
parts
)
t
.
Fatalf
(
"expected 1 functionCall part, got %+v"
,
parts
)
}
}
if
parts
[
0
]
.
ThoughtSignature
!=
"sig_tool_abc"
{
t
.
Fatalf
(
"expected preserved tool signature %q, got %q"
,
"sig_tool_abc"
,
parts
[
0
]
.
ThoughtSignature
)
}
})
t
.
Run
(
"Gemini falls back to dummy tool_use signature when missing"
,
func
(
t
*
testing
.
T
)
{
contentNoSig
:=
`[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName
:=
make
(
map
[
string
]
string
)
parts
,
_
,
err
:=
buildParts
(
json
.
RawMessage
(
contentNoSig
),
toolIDToName
,
true
)
if
err
!=
nil
{
t
.
Fatalf
(
"buildParts() error = %v"
,
err
)
}
if
len
(
parts
)
!=
1
||
parts
[
0
]
.
FunctionCall
==
nil
{
t
.
Fatalf
(
"expected 1 functionCall part, got %+v"
,
parts
)
}
if
parts
[
0
]
.
ThoughtSignature
!=
dummyThoughtSignature
{
if
parts
[
0
]
.
ThoughtSignature
!=
dummyThoughtSignature
{
t
.
Fatalf
(
"expected dummy tool signature %q, got %q"
,
dummyThoughtSignature
,
parts
[
0
]
.
ThoughtSignature
)
t
.
Fatalf
(
"expected dummy tool signature %q, got %q"
,
dummyThoughtSignature
,
parts
[
0
]
.
ThoughtSignature
)
}
}
...
...
backend/internal/pkg/antigravity/response_transformer.go
View file @
a161fcc8
...
@@ -3,6 +3,7 @@ package antigravity
...
@@ -3,6 +3,7 @@ package antigravity
import
(
import
(
"encoding/json"
"encoding/json"
"fmt"
"fmt"
"log"
"strings"
"strings"
)
)
...
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
...
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp
.
Response
=
directResp
v1Resp
.
Response
=
directResp
v1Resp
.
ResponseID
=
directResp
.
ResponseID
v1Resp
.
ResponseID
=
directResp
.
ResponseID
v1Resp
.
ModelVersion
=
directResp
.
ModelVersion
v1Resp
.
ModelVersion
=
directResp
.
ModelVersion
}
else
if
len
(
v1Resp
.
Response
.
Candidates
)
==
0
{
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
var
directResp
GeminiResponse
if
err2
:=
json
.
Unmarshal
(
geminiResp
,
&
directResp
);
err2
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"parse gemini response as direct: %w"
,
err2
)
}
v1Resp
.
Response
=
directResp
v1Resp
.
ResponseID
=
directResp
.
ResponseID
v1Resp
.
ModelVersion
=
directResp
.
ModelVersion
}
}
// 使用处理器转换
// 使用处理器转换
...
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
...
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p
.
trailingSignature
=
""
p
.
trailingSignature
=
""
}
}
p
.
textBuilder
+=
part
.
Text
// 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
if
signature
!=
""
{
if
signature
!=
""
{
p
.
flushText
()
p
.
contentBlocks
=
append
(
p
.
contentBlocks
,
ClaudeContentItem
{
Type
:
"text"
,
Text
:
part
.
Text
,
})
p
.
contentBlocks
=
append
(
p
.
contentBlocks
,
ClaudeContentItem
{
p
.
contentBlocks
=
append
(
p
.
contentBlocks
,
ClaudeContentItem
{
Type
:
"thinking"
,
Type
:
"thinking"
,
Thinking
:
""
,
Thinking
:
""
,
Signature
:
signature
,
Signature
:
signature
,
})
})
}
else
{
// 普通 text (无签名) - 累积到 builder
p
.
textBuilder
+=
part
.
Text
}
}
}
}
}
}
...
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
...
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
var
finishReason
string
var
finishReason
string
if
len
(
geminiResp
.
Candidates
)
>
0
{
if
len
(
geminiResp
.
Candidates
)
>
0
{
finishReason
=
geminiResp
.
Candidates
[
0
]
.
FinishReason
finishReason
=
geminiResp
.
Candidates
[
0
]
.
FinishReason
if
finishReason
==
"MALFORMED_FUNCTION_CALL"
{
log
.
Printf
(
"[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s"
,
originalModel
)
if
geminiResp
.
Candidates
[
0
]
.
Content
!=
nil
{
if
b
,
err
:=
json
.
Marshal
(
geminiResp
.
Candidates
[
0
]
.
Content
);
err
==
nil
{
log
.
Printf
(
"[Antigravity] Malformed content: %s"
,
string
(
b
))
}
}
}
}
}
stopReason
:=
"end_turn"
stopReason
:=
"end_turn"
...
...
backend/internal/pkg/antigravity/schema_cleaner.go
0 → 100644
View file @
a161fcc8
package
antigravity
import
(
"fmt"
"strings"
)
// CleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
func
CleanJSONSchema
(
schema
map
[
string
]
any
)
map
[
string
]
any
{
if
schema
==
nil
{
return
nil
}
// 0. 预处理:展开 $ref (Schema Flattening)
// (Go map 是引用的,直接修改 schema)
flattenRefs
(
schema
,
extractDefs
(
schema
))
// 递归清理
cleaned
:=
cleanJSONSchemaRecursive
(
schema
)
result
,
ok
:=
cleaned
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
return
result
}
// extractDefs 提取并移除定义的 helper
func
extractDefs
(
schema
map
[
string
]
any
)
map
[
string
]
any
{
defs
:=
make
(
map
[
string
]
any
)
if
d
,
ok
:=
schema
[
"$defs"
]
.
(
map
[
string
]
any
);
ok
{
for
k
,
v
:=
range
d
{
defs
[
k
]
=
v
}
delete
(
schema
,
"$defs"
)
}
if
d
,
ok
:=
schema
[
"definitions"
]
.
(
map
[
string
]
any
);
ok
{
for
k
,
v
:=
range
d
{
defs
[
k
]
=
v
}
delete
(
schema
,
"definitions"
)
}
return
defs
}
// flattenRefs 递归展开 $ref
func
flattenRefs
(
schema
map
[
string
]
any
,
defs
map
[
string
]
any
)
{
if
len
(
defs
)
==
0
{
return
// 无需展开
}
// 检查并替换 $ref
if
ref
,
ok
:=
schema
[
"$ref"
]
.
(
string
);
ok
{
delete
(
schema
,
"$ref"
)
// 解析引用名 (例如 #/$defs/MyType -> MyType)
parts
:=
strings
.
Split
(
ref
,
"/"
)
refName
:=
parts
[
len
(
parts
)
-
1
]
if
defSchema
,
exists
:=
defs
[
refName
];
exists
{
if
defMap
,
ok
:=
defSchema
.
(
map
[
string
]
any
);
ok
{
// 合并定义内容 (不覆盖现有 key)
for
k
,
v
:=
range
defMap
{
if
_
,
has
:=
schema
[
k
];
!
has
{
schema
[
k
]
=
deepCopy
(
v
)
// 需深拷贝避免共享引用
}
}
// 递归处理刚刚合并进来的内容
flattenRefs
(
schema
,
defs
)
}
}
}
// 遍历子节点
for
_
,
v
:=
range
schema
{
if
subMap
,
ok
:=
v
.
(
map
[
string
]
any
);
ok
{
flattenRefs
(
subMap
,
defs
)
}
else
if
subArr
,
ok
:=
v
.
([]
any
);
ok
{
for
_
,
item
:=
range
subArr
{
if
itemMap
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
flattenRefs
(
itemMap
,
defs
)
}
}
}
}
}
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
func
deepCopy
(
src
any
)
any
{
if
src
==
nil
{
return
nil
}
switch
v
:=
src
.
(
type
)
{
case
map
[
string
]
any
:
dst
:=
make
(
map
[
string
]
any
)
for
k
,
val
:=
range
v
{
dst
[
k
]
=
deepCopy
(
val
)
}
return
dst
case
[]
any
:
dst
:=
make
([]
any
,
len
(
v
))
for
i
,
val
:=
range
v
{
dst
[
i
]
=
deepCopy
(
val
)
}
return
dst
default
:
return
src
}
}
// cleanJSONSchemaRecursive 递归核心清理逻辑
// 返回处理后的值 (通常是 input map,但可能修改内部结构)
func
cleanJSONSchemaRecursive
(
value
any
)
any
{
schemaMap
,
ok
:=
value
.
(
map
[
string
]
any
)
if
!
ok
{
return
value
}
// 0. [NEW] 合并 allOf
mergeAllOf
(
schemaMap
)
// 1. [CRITICAL] 深度递归处理子项
if
props
,
ok
:=
schemaMap
[
"properties"
]
.
(
map
[
string
]
any
);
ok
{
for
_
,
v
:=
range
props
{
cleanJSONSchemaRecursive
(
v
)
}
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required,
// 因为我们在子项处理中会正确设置 type 和 description
}
else
if
items
,
ok
:=
schemaMap
[
"items"
];
ok
{
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
if
itemsArr
,
ok
:=
items
.
([]
any
);
ok
{
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
best
:=
extractBestSchemaFromUnion
(
itemsArr
)
if
best
==
nil
{
// 回退到通用字符串
best
=
map
[
string
]
any
{
"type"
:
"string"
}
}
// 用处理后的对象替换原有数组
cleanedBest
:=
cleanJSONSchemaRecursive
(
best
)
schemaMap
[
"items"
]
=
cleanedBest
}
else
{
cleanJSONSchemaRecursive
(
items
)
}
}
else
{
// 遍历所有值递归
for
_
,
v
:=
range
schemaMap
{
if
_
,
isMap
:=
v
.
(
map
[
string
]
any
);
isMap
{
cleanJSONSchemaRecursive
(
v
)
}
else
if
arr
,
isArr
:=
v
.
([]
any
);
isArr
{
for
_
,
item
:=
range
arr
{
cleanJSONSchemaRecursive
(
item
)
}
}
}
}
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
var
unionArray
[]
any
typeStr
,
_
:=
schemaMap
[
"type"
]
.
(
string
)
if
typeStr
==
""
||
typeStr
==
"object"
{
if
anyOf
,
ok
:=
schemaMap
[
"anyOf"
]
.
([]
any
);
ok
{
unionArray
=
anyOf
}
else
if
oneOf
,
ok
:=
schemaMap
[
"oneOf"
]
.
([]
any
);
ok
{
unionArray
=
oneOf
}
}
if
len
(
unionArray
)
>
0
{
if
bestBranch
:=
extractBestSchemaFromUnion
(
unionArray
);
bestBranch
!=
nil
{
if
bestMap
,
ok
:=
bestBranch
.
(
map
[
string
]
any
);
ok
{
// 合并分支内容
for
k
,
v
:=
range
bestMap
{
if
k
==
"properties"
{
targetProps
,
_
:=
schemaMap
[
"properties"
]
.
(
map
[
string
]
any
)
if
targetProps
==
nil
{
targetProps
=
make
(
map
[
string
]
any
)
schemaMap
[
"properties"
]
=
targetProps
}
if
sourceProps
,
ok
:=
v
.
(
map
[
string
]
any
);
ok
{
for
pk
,
pv
:=
range
sourceProps
{
if
_
,
exists
:=
targetProps
[
pk
];
!
exists
{
targetProps
[
pk
]
=
deepCopy
(
pv
)
}
}
}
}
else
if
k
==
"required"
{
targetReq
,
_
:=
schemaMap
[
"required"
]
.
([]
any
)
if
sourceReq
,
ok
:=
v
.
([]
any
);
ok
{
for
_
,
rv
:=
range
sourceReq
{
// 简单的去重添加
exists
:=
false
for
_
,
tr
:=
range
targetReq
{
if
tr
==
rv
{
exists
=
true
break
}
}
if
!
exists
{
targetReq
=
append
(
targetReq
,
rv
)
}
}
schemaMap
[
"required"
]
=
targetReq
}
}
else
if
_
,
exists
:=
schemaMap
[
k
];
!
exists
{
schemaMap
[
k
]
=
deepCopy
(
v
)
}
}
}
}
}
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
looksLikeSchema
:=
hasKey
(
schemaMap
,
"type"
)
||
hasKey
(
schemaMap
,
"properties"
)
||
hasKey
(
schemaMap
,
"items"
)
||
hasKey
(
schemaMap
,
"enum"
)
||
hasKey
(
schemaMap
,
"anyOf"
)
||
hasKey
(
schemaMap
,
"oneOf"
)
||
hasKey
(
schemaMap
,
"allOf"
)
if
looksLikeSchema
{
// 4. [ROBUST] 约束迁移
migrateConstraints
(
schemaMap
)
// 5. [CRITICAL] 白名单过滤
allowedFields
:=
map
[
string
]
bool
{
"type"
:
true
,
"description"
:
true
,
"properties"
:
true
,
"required"
:
true
,
"items"
:
true
,
"enum"
:
true
,
"title"
:
true
,
}
for
k
:=
range
schemaMap
{
if
!
allowedFields
[
k
]
{
delete
(
schemaMap
,
k
)
}
}
// 6. [SAFETY] 处理空 Object
if
t
,
_
:=
schemaMap
[
"type"
]
.
(
string
);
t
==
"object"
{
hasProps
:=
false
if
props
,
ok
:=
schemaMap
[
"properties"
]
.
(
map
[
string
]
any
);
ok
&&
len
(
props
)
>
0
{
hasProps
=
true
}
if
!
hasProps
{
schemaMap
[
"properties"
]
=
map
[
string
]
any
{
"reason"
:
map
[
string
]
any
{
"type"
:
"string"
,
"description"
:
"Reason for calling this tool"
,
},
}
schemaMap
[
"required"
]
=
[]
any
{
"reason"
}
}
}
// 7. [SAFETY] Required 字段对齐
if
props
,
ok
:=
schemaMap
[
"properties"
]
.
(
map
[
string
]
any
);
ok
{
if
req
,
ok
:=
schemaMap
[
"required"
]
.
([]
any
);
ok
{
var
validReq
[]
any
for
_
,
r
:=
range
req
{
if
rStr
,
ok
:=
r
.
(
string
);
ok
{
if
_
,
exists
:=
props
[
rStr
];
exists
{
validReq
=
append
(
validReq
,
r
)
}
}
}
if
len
(
validReq
)
>
0
{
schemaMap
[
"required"
]
=
validReq
}
else
{
delete
(
schemaMap
,
"required"
)
}
}
}
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
isEffectivelyNullable
:=
false
if
typeVal
,
exists
:=
schemaMap
[
"type"
];
exists
{
var
selectedType
string
switch
v
:=
typeVal
.
(
type
)
{
case
string
:
lower
:=
strings
.
ToLower
(
v
)
if
lower
==
"null"
{
isEffectivelyNullable
=
true
selectedType
=
"string"
// fallback
}
else
{
selectedType
=
lower
}
case
[]
any
:
// ["string", "null"]
for
_
,
t
:=
range
v
{
if
ts
,
ok
:=
t
.
(
string
);
ok
{
lower
:=
strings
.
ToLower
(
ts
)
if
lower
==
"null"
{
isEffectivelyNullable
=
true
}
else
if
selectedType
==
""
{
selectedType
=
lower
}
}
}
if
selectedType
==
""
{
selectedType
=
"string"
}
}
schemaMap
[
"type"
]
=
selectedType
}
else
{
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
// 如果没有 type,但有 properties,补一个
if
hasKey
(
schemaMap
,
"properties"
)
{
schemaMap
[
"type"
]
=
"object"
}
else
{
// 默认为 string ? or object? Gemini 通常需要明确 type
schemaMap
[
"type"
]
=
"object"
}
}
if
isEffectivelyNullable
{
desc
,
_
:=
schemaMap
[
"description"
]
.
(
string
)
if
!
strings
.
Contains
(
desc
,
"nullable"
)
{
if
desc
!=
""
{
desc
+=
" "
}
desc
+=
"(nullable)"
schemaMap
[
"description"
]
=
desc
}
}
// 9. Enum 值强制转字符串
if
enumVals
,
ok
:=
schemaMap
[
"enum"
]
.
([]
any
);
ok
{
hasNonString
:=
false
for
i
,
val
:=
range
enumVals
{
if
_
,
isStr
:=
val
.
(
string
);
!
isStr
{
hasNonString
=
true
if
val
==
nil
{
enumVals
[
i
]
=
"null"
}
else
{
enumVals
[
i
]
=
fmt
.
Sprintf
(
"%v"
,
val
)
}
}
}
// If we mandated string values, we must ensure type is string
if
hasNonString
{
schemaMap
[
"type"
]
=
"string"
}
}
}
return
schemaMap
}
func
hasKey
(
m
map
[
string
]
any
,
k
string
)
bool
{
_
,
ok
:=
m
[
k
]
return
ok
}
func
migrateConstraints
(
m
map
[
string
]
any
)
{
constraints
:=
[]
struct
{
key
string
label
string
}{
{
"minLength"
,
"minLen"
},
{
"maxLength"
,
"maxLen"
},
{
"pattern"
,
"pattern"
},
{
"minimum"
,
"min"
},
{
"maximum"
,
"max"
},
{
"multipleOf"
,
"multipleOf"
},
{
"exclusiveMinimum"
,
"exclMin"
},
{
"exclusiveMaximum"
,
"exclMax"
},
{
"minItems"
,
"minItems"
},
{
"maxItems"
,
"maxItems"
},
{
"propertyNames"
,
"propertyNames"
},
{
"format"
,
"format"
},
}
var
hints
[]
string
for
_
,
c
:=
range
constraints
{
if
val
,
ok
:=
m
[
c
.
key
];
ok
&&
val
!=
nil
{
hints
=
append
(
hints
,
fmt
.
Sprintf
(
"%s: %v"
,
c
.
label
,
val
))
}
}
if
len
(
hints
)
>
0
{
suffix
:=
fmt
.
Sprintf
(
" [Constraint: %s]"
,
strings
.
Join
(
hints
,
", "
))
desc
,
_
:=
m
[
"description"
]
.
(
string
)
if
!
strings
.
Contains
(
desc
,
suffix
)
{
m
[
"description"
]
=
desc
+
suffix
}
}
}
// mergeAllOf 合并 allOf
func
mergeAllOf
(
m
map
[
string
]
any
)
{
allOf
,
ok
:=
m
[
"allOf"
]
.
([]
any
)
if
!
ok
{
return
}
delete
(
m
,
"allOf"
)
mergedProps
:=
make
(
map
[
string
]
any
)
mergedReq
:=
make
(
map
[
string
]
bool
)
otherFields
:=
make
(
map
[
string
]
any
)
for
_
,
sub
:=
range
allOf
{
if
subMap
,
ok
:=
sub
.
(
map
[
string
]
any
);
ok
{
// Props
if
props
,
ok
:=
subMap
[
"properties"
]
.
(
map
[
string
]
any
);
ok
{
for
k
,
v
:=
range
props
{
mergedProps
[
k
]
=
v
}
}
// Required
if
reqs
,
ok
:=
subMap
[
"required"
]
.
([]
any
);
ok
{
for
_
,
r
:=
range
reqs
{
if
s
,
ok
:=
r
.
(
string
);
ok
{
mergedReq
[
s
]
=
true
}
}
}
// Others
for
k
,
v
:=
range
subMap
{
if
k
!=
"properties"
&&
k
!=
"required"
&&
k
!=
"allOf"
{
if
_
,
exists
:=
otherFields
[
k
];
!
exists
{
otherFields
[
k
]
=
v
}
}
}
}
}
// Apply
for
k
,
v
:=
range
otherFields
{
if
_
,
exists
:=
m
[
k
];
!
exists
{
m
[
k
]
=
v
}
}
if
len
(
mergedProps
)
>
0
{
existProps
,
_
:=
m
[
"properties"
]
.
(
map
[
string
]
any
)
if
existProps
==
nil
{
existProps
=
make
(
map
[
string
]
any
)
m
[
"properties"
]
=
existProps
}
for
k
,
v
:=
range
mergedProps
{
if
_
,
exists
:=
existProps
[
k
];
!
exists
{
existProps
[
k
]
=
v
}
}
}
if
len
(
mergedReq
)
>
0
{
existReq
,
_
:=
m
[
"required"
]
.
([]
any
)
var
validReqs
[]
any
for
_
,
r
:=
range
existReq
{
if
s
,
ok
:=
r
.
(
string
);
ok
{
validReqs
=
append
(
validReqs
,
s
)
delete
(
mergedReq
,
s
)
// already exists
}
}
// append new
for
r
:=
range
mergedReq
{
validReqs
=
append
(
validReqs
,
r
)
}
m
[
"required"
]
=
validReqs
}
}
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
func
extractBestSchemaFromUnion
(
unionArray
[]
any
)
any
{
var
bestOption
any
bestScore
:=
-
1
for
_
,
item
:=
range
unionArray
{
score
:=
scoreSchemaOption
(
item
)
if
score
>
bestScore
{
bestScore
=
score
bestOption
=
item
}
}
return
bestOption
}
func
scoreSchemaOption
(
val
any
)
int
{
m
,
ok
:=
val
.
(
map
[
string
]
any
)
if
!
ok
{
return
0
}
typeStr
,
_
:=
m
[
"type"
]
.
(
string
)
if
hasKey
(
m
,
"properties"
)
||
typeStr
==
"object"
{
return
3
}
if
hasKey
(
m
,
"items"
)
||
typeStr
==
"array"
{
return
2
}
if
typeStr
!=
""
&&
typeStr
!=
"null"
{
return
1
}
return
0
}
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
func
DeepCleanUndefined
(
value
any
)
{
if
value
==
nil
{
return
}
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
for
k
,
val
:=
range
v
{
if
s
,
ok
:=
val
.
(
string
);
ok
&&
s
==
"[undefined]"
{
delete
(
v
,
k
)
continue
}
DeepCleanUndefined
(
val
)
}
case
[]
any
:
for
_
,
val
:=
range
v
{
DeepCleanUndefined
(
val
)
}
}
}
backend/internal/pkg/antigravity/stream_transformer.go
View file @
a161fcc8
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"bytes"
"bytes"
"encoding/json"
"encoding/json"
"fmt"
"fmt"
"log"
"strings"
"strings"
)
)
...
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
...
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 检查是否结束
// 检查是否结束
if
len
(
geminiResp
.
Candidates
)
>
0
{
if
len
(
geminiResp
.
Candidates
)
>
0
{
finishReason
:=
geminiResp
.
Candidates
[
0
]
.
FinishReason
finishReason
:=
geminiResp
.
Candidates
[
0
]
.
FinishReason
if
finishReason
==
"MALFORMED_FUNCTION_CALL"
{
log
.
Printf
(
"[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s"
,
p
.
originalModel
)
if
geminiResp
.
Candidates
[
0
]
.
Content
!=
nil
{
if
b
,
err
:=
json
.
Marshal
(
geminiResp
.
Candidates
[
0
]
.
Content
);
err
==
nil
{
log
.
Printf
(
"[Antigravity] Malformed content: %s"
,
string
(
b
))
}
}
}
if
finishReason
!=
""
{
if
finishReason
!=
""
{
_
,
_
=
result
.
Write
(
p
.
emitFinish
(
finishReason
))
_
,
_
=
result
.
Write
(
p
.
emitFinish
(
finishReason
))
}
}
...
...
backend/internal/pkg/oauth/oauth.go
View file @
a161fcc8
...
@@ -24,9 +24,9 @@ const (
...
@@ -24,9 +24,9 @@ const (
RedirectURI
=
"https://platform.claude.com/oauth/code/callback"
RedirectURI
=
"https://platform.claude.com/oauth/code/callback"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
// Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth
=
"org:create_api_key user:profile user:inference user:sessions:claude_code"
ScopeOAuth
=
"org:create_api_key user:profile user:inference user:sessions:claude_code
user:mcp_servers
"
// Scopes - Internal API call (org:create_api_key not supported in API)
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI
=
"user:profile user:inference user:sessions:claude_code"
ScopeAPI
=
"user:profile user:inference user:sessions:claude_code
user:mcp_servers
"
// Scopes - Setup token (inference only)
// Scopes - Setup token (inference only)
ScopeInference
=
"user:inference"
ScopeInference
=
"user:inference"
...
@@ -216,4 +216,5 @@ type OrgInfo struct {
...
@@ -216,4 +216,5 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response
// AccountInfo represents account info from OAuth response
type
AccountInfo
struct
{
type
AccountInfo
struct
{
UUID
string
`json:"uuid"`
UUID
string
`json:"uuid"`
EmailAddress
string
`json:"email_address"`
}
}
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
0 → 100644
View file @
a161fcc8
//go:build integration
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests to external services and should be run manually.
//
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package
tlsfingerprint
import
(
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// skipIfExternalServiceUnavailable checks if the external service is available.
// If not, it skips the test instead of failing.
func
skipIfExternalServiceUnavailable
(
t
*
testing
.
T
,
err
error
)
{
t
.
Helper
()
if
err
!=
nil
{
// Check for common network/TLS errors that indicate external service issues
errStr
:=
err
.
Error
()
if
strings
.
Contains
(
errStr
,
"certificate has expired"
)
||
strings
.
Contains
(
errStr
,
"certificate is not yet valid"
)
||
strings
.
Contains
(
errStr
,
"connection refused"
)
||
strings
.
Contains
(
errStr
,
"no such host"
)
||
strings
.
Contains
(
errStr
,
"network is unreachable"
)
||
strings
.
Contains
(
errStr
,
"timeout"
)
{
t
.
Skipf
(
"skipping test: external service unavailable: %v"
,
err
)
}
t
.
Fatalf
(
"failed to get fingerprint: %v"
,
err
)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func
TestJA3Fingerprint
(
t
*
testing
.
T
)
{
// Skip if network is unavailable or if running in short mode
if
testing
.
Short
()
{
t
.
Skip
(
"skipping integration test in short mode"
)
}
profile
:=
&
Profile
{
Name
:
"Claude CLI Test"
,
EnableGREASE
:
false
,
}
dialer
:=
NewDialer
(
profile
,
nil
)
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
// Use tls.peet.ws fingerprint detection API
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://tls.peet.ws/api/all"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to create request: %v"
,
err
)
}
req
.
Header
.
Set
(
"User-Agent"
,
"Claude Code/2.0.0 Node.js/20.0.0"
)
resp
,
err
:=
client
.
Do
(
req
)
skipIfExternalServiceUnavailable
(
t
,
err
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to read response: %v"
,
err
)
}
var
fpResp
FingerprintResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
fpResp
);
err
!=
nil
{
t
.
Logf
(
"Response body: %s"
,
string
(
body
))
t
.
Fatalf
(
"failed to parse fingerprint response: %v"
,
err
)
}
// Log all fingerprint information
t
.
Logf
(
"JA3: %s"
,
fpResp
.
TLS
.
JA3
)
t
.
Logf
(
"JA3 Hash: %s"
,
fpResp
.
TLS
.
JA3Hash
)
t
.
Logf
(
"JA4: %s"
,
fpResp
.
TLS
.
JA4
)
t
.
Logf
(
"PeetPrint: %s"
,
fpResp
.
TLS
.
PeetPrint
)
t
.
Logf
(
"PeetPrint Hash: %s"
,
fpResp
.
TLS
.
PeetPrintHash
)
// Verify JA3 hash matches expected value
expectedJA3Hash
:=
"1a28e69016765d92e3b381168d68922c"
if
fpResp
.
TLS
.
JA3Hash
==
expectedJA3Hash
{
t
.
Logf
(
"✓ JA3 hash matches expected value: %s"
,
expectedJA3Hash
)
}
else
{
t
.
Errorf
(
"✗ JA3 hash mismatch: got %s, expected %s"
,
fpResp
.
TLS
.
JA3Hash
,
expectedJA3Hash
)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix
:=
"_a33745022dd6_1f22a2ca17c4"
if
strings
.
HasSuffix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
{
t
.
Logf
(
"✓ JA4 suffix matches expected value: %s"
,
expectedJA4Suffix
)
}
else
{
t
.
Errorf
(
"✗ JA4 suffix mismatch: got %s, expected suffix %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix
:=
"t13d5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)"
,
expectedJA4Prefix
)
}
else
{
// Also accept 'i' variant for IP connections
altPrefix
:=
"t13i5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
altPrefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches (IP variant): %s"
,
altPrefix
)
}
else
{
t
.
Errorf
(
"✗ JA4 prefix mismatch: got %s, expected %s or %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
,
altPrefix
)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
"4866-4867-4865"
)
{
t
.
Logf
(
"✓ JA3 contains expected TLS 1.3 cipher suites"
)
}
else
{
t
.
Logf
(
"Warning: JA3 does not contain expected TLS 1.3 cipher suites"
)
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions
:=
"0-11-10-35-16-22-23-13-43-45-51"
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
expectedExtensions
)
{
t
.
Logf
(
"✓ JA3 contains expected extension list: %s"
,
expectedExtensions
)
}
else
{
t
.
Logf
(
"Warning: JA3 extension list may differ"
)
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type
TestProfileExpectation
struct
{
Profile
*
Profile
ExpectedJA3
string
// Expected JA3 hash (empty = don't check)
ExpectedJA4
string
// Expected full JA4 (empty = don't check)
JA4CipherHash
string
// Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func
TestAllProfiles
(
t
*
testing
.
T
)
{
if
testing
.
Short
()
{
t
.
Skip
(
"skipping integration test in short mode"
)
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles
:=
[]
TestProfileExpectation
{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile
:
&
Profile
{
Name
:
"linux_x64_node_v22171"
,
EnableGREASE
:
false
,
CipherSuites
:
[]
uint16
{
4866
,
4867
,
4865
,
49199
,
49195
,
49200
,
49196
,
158
,
49191
,
103
,
49192
,
107
,
163
,
159
,
52393
,
52392
,
52394
,
49327
,
49325
,
49315
,
49311
,
49245
,
49249
,
49239
,
49235
,
162
,
49326
,
49324
,
49314
,
49310
,
49244
,
49248
,
49238
,
49234
,
49188
,
106
,
49187
,
64
,
49162
,
49172
,
57
,
56
,
49161
,
49171
,
51
,
50
,
157
,
49313
,
49309
,
49233
,
156
,
49312
,
49308
,
49232
,
61
,
60
,
53
,
47
,
255
},
Curves
:
[]
uint16
{
29
,
23
,
30
,
25
,
24
,
256
,
257
,
258
,
259
,
260
},
PointFormats
:
[]
uint8
{
0
,
1
,
2
},
},
JA4CipherHash
:
"a33745022dd6"
,
// stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile
:
&
Profile
{
Name
:
"macos_arm64_node_v22180"
,
EnableGREASE
:
false
,
CipherSuites
:
[]
uint16
{
4866
,
4867
,
4865
,
49199
,
49195
,
49200
,
49196
,
158
,
49191
,
103
,
49192
,
107
,
163
,
159
,
52393
,
52392
,
52394
,
49327
,
49325
,
49315
,
49311
,
49245
,
49249
,
49239
,
49235
,
162
,
49326
,
49324
,
49314
,
49310
,
49244
,
49248
,
49238
,
49234
,
49188
,
106
,
49187
,
64
,
49162
,
49172
,
57
,
56
,
49161
,
49171
,
51
,
50
,
157
,
49313
,
49309
,
49233
,
156
,
49312
,
49308
,
49232
,
61
,
60
,
53
,
47
,
255
},
Curves
:
[]
uint16
{
29
,
23
,
30
,
25
,
24
,
256
,
257
,
258
,
259
,
260
},
PointFormats
:
[]
uint8
{
0
,
1
,
2
},
},
JA4CipherHash
:
"a33745022dd6"
,
// stable part (same cipher suites)
},
}
for
_
,
tc
:=
range
profiles
{
tc
:=
tc
// capture range variable
t
.
Run
(
tc
.
Profile
.
Name
,
func
(
t
*
testing
.
T
)
{
fp
:=
fetchFingerprint
(
t
,
tc
.
Profile
)
if
fp
==
nil
{
return
// fetchFingerprint already called t.Fatal
}
t
.
Logf
(
"Profile: %s"
,
tc
.
Profile
.
Name
)
t
.
Logf
(
" JA3: %s"
,
fp
.
JA3
)
t
.
Logf
(
" JA3 Hash: %s"
,
fp
.
JA3Hash
)
t
.
Logf
(
" JA4: %s"
,
fp
.
JA4
)
t
.
Logf
(
" PeetPrint: %s"
,
fp
.
PeetPrint
)
t
.
Logf
(
" PeetPrintHash: %s"
,
fp
.
PeetPrintHash
)
// Verify expectations
if
tc
.
ExpectedJA3
!=
""
{
if
fp
.
JA3Hash
==
tc
.
ExpectedJA3
{
t
.
Logf
(
" ✓ JA3 hash matches: %s"
,
tc
.
ExpectedJA3
)
}
else
{
t
.
Errorf
(
" ✗ JA3 hash mismatch: got %s, expected %s"
,
fp
.
JA3Hash
,
tc
.
ExpectedJA3
)
}
}
if
tc
.
ExpectedJA4
!=
""
{
if
fp
.
JA4
==
tc
.
ExpectedJA4
{
t
.
Logf
(
" ✓ JA4 matches: %s"
,
tc
.
ExpectedJA4
)
}
else
{
t
.
Errorf
(
" ✗ JA4 mismatch: got %s, expected %s"
,
fp
.
JA4
,
tc
.
ExpectedJA4
)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if
tc
.
JA4CipherHash
!=
""
{
if
strings
.
Contains
(
fp
.
JA4
,
"_"
+
tc
.
JA4CipherHash
+
"_"
)
{
t
.
Logf
(
" ✓ JA4 cipher hash matches: %s"
,
tc
.
JA4CipherHash
)
}
else
{
t
.
Errorf
(
" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s"
,
fp
.
JA4
,
tc
.
JA4CipherHash
)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func
fetchFingerprint
(
t
*
testing
.
T
,
profile
*
Profile
)
*
TLSInfo
{
t
.
Helper
()
dialer
:=
NewDialer
(
profile
,
nil
)
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://tls.peet.ws/api/all"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to create request: %v"
,
err
)
return
nil
}
req
.
Header
.
Set
(
"User-Agent"
,
"Claude Code/2.0.0 Node.js/20.0.0"
)
resp
,
err
:=
client
.
Do
(
req
)
skipIfExternalServiceUnavailable
(
t
,
err
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to read response: %v"
,
err
)
return
nil
}
var
fpResp
FingerprintResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
fpResp
);
err
!=
nil
{
t
.
Logf
(
"Response body: %s"
,
string
(
body
))
t
.
Fatalf
(
"failed to parse fingerprint response: %v"
,
err
)
return
nil
}
return
&
fpResp
.
TLS
}
backend/internal/pkg/tlsfingerprint/dialer_test.go
View file @
a161fcc8
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
//
// Integration tests for verifying TLS fingerprint correctness.
// Unit tests for TLS fingerprint dialer.
// These tests make actual network requests and should be run manually.
// Integration tests that require external network are in dialer_integration_test.go
// and require the 'integration' build tag.
//
//
// Run
with
: go test -v ./internal/pkg/tlsfingerprint/...
// Run
unit tests
: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -
run TestJA3
./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -
tags=integration
./internal/pkg/tlsfingerprint/...
package
tlsfingerprint
package
tlsfingerprint
import
(
import
(
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"net/url"
"strings"
"testing"
"testing"
"time"
)
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
// FingerprintResponse represents the response from tls.peet.ws/api/all.
...
@@ -36,148 +31,6 @@ type TLSInfo struct {
...
@@ -36,148 +31,6 @@ type TLSInfo struct {
SessionID
string
`json:"session_id"`
SessionID
string
`json:"session_id"`
}
}
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func
TestDialerBasicConnection
(
t
*
testing
.
T
)
{
if
testing
.
Short
()
{
t
.
Skip
(
"skipping network test in short mode"
)
}
// Create a dialer with default profile
profile
:=
&
Profile
{
Name
:
"Test Profile"
,
EnableGREASE
:
false
,
}
dialer
:=
NewDialer
(
profile
,
nil
)
// Create HTTP client with custom TLS dialer
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
// Make a request to a known HTTPS endpoint
resp
,
err
:=
client
.
Get
(
"https://www.google.com"
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to connect: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
t
.
Errorf
(
"expected status 200, got %d"
,
resp
.
StatusCode
)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func
TestJA3Fingerprint
(
t
*
testing
.
T
)
{
// Skip if network is unavailable or if running in short mode
if
testing
.
Short
()
{
t
.
Skip
(
"skipping integration test in short mode"
)
}
profile
:=
&
Profile
{
Name
:
"Claude CLI Test"
,
EnableGREASE
:
false
,
}
dialer
:=
NewDialer
(
profile
,
nil
)
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
// Use tls.peet.ws fingerprint detection API
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://tls.peet.ws/api/all"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to create request: %v"
,
err
)
}
req
.
Header
.
Set
(
"User-Agent"
,
"Claude Code/2.0.0 Node.js/20.0.0"
)
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to get fingerprint: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to read response: %v"
,
err
)
}
var
fpResp
FingerprintResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
fpResp
);
err
!=
nil
{
t
.
Logf
(
"Response body: %s"
,
string
(
body
))
t
.
Fatalf
(
"failed to parse fingerprint response: %v"
,
err
)
}
// Log all fingerprint information
t
.
Logf
(
"JA3: %s"
,
fpResp
.
TLS
.
JA3
)
t
.
Logf
(
"JA3 Hash: %s"
,
fpResp
.
TLS
.
JA3Hash
)
t
.
Logf
(
"JA4: %s"
,
fpResp
.
TLS
.
JA4
)
t
.
Logf
(
"PeetPrint: %s"
,
fpResp
.
TLS
.
PeetPrint
)
t
.
Logf
(
"PeetPrint Hash: %s"
,
fpResp
.
TLS
.
PeetPrintHash
)
// Verify JA3 hash matches expected value
expectedJA3Hash
:=
"1a28e69016765d92e3b381168d68922c"
if
fpResp
.
TLS
.
JA3Hash
==
expectedJA3Hash
{
t
.
Logf
(
"✓ JA3 hash matches expected value: %s"
,
expectedJA3Hash
)
}
else
{
t
.
Errorf
(
"✗ JA3 hash mismatch: got %s, expected %s"
,
fpResp
.
TLS
.
JA3Hash
,
expectedJA3Hash
)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix
:=
"_a33745022dd6_1f22a2ca17c4"
if
strings
.
HasSuffix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
{
t
.
Logf
(
"✓ JA4 suffix matches expected value: %s"
,
expectedJA4Suffix
)
}
else
{
t
.
Errorf
(
"✗ JA4 suffix mismatch: got %s, expected suffix %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix
:=
"t13d5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)"
,
expectedJA4Prefix
)
}
else
{
// Also accept 'i' variant for IP connections
altPrefix
:=
"t13i5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
altPrefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches (IP variant): %s"
,
altPrefix
)
}
else
{
t
.
Errorf
(
"✗ JA4 prefix mismatch: got %s, expected %s or %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
,
altPrefix
)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
"4866-4867-4865"
)
{
t
.
Logf
(
"✓ JA3 contains expected TLS 1.3 cipher suites"
)
}
else
{
t
.
Logf
(
"Warning: JA3 does not contain expected TLS 1.3 cipher suites"
)
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions
:=
"0-11-10-35-16-22-23-13-43-45-51"
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
expectedExtensions
)
{
t
.
Logf
(
"✓ JA3 contains expected extension list: %s"
,
expectedExtensions
)
}
else
{
t
.
Logf
(
"Warning: JA3 extension list may differ"
)
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func
TestDialerWithProfile
(
t
*
testing
.
T
)
{
func
TestDialerWithProfile
(
t
*
testing
.
T
)
{
// Create two dialers with different profiles
// Create two dialers with different profiles
...
@@ -305,139 +158,3 @@ func mustParseURL(rawURL string) *url.URL {
...
@@ -305,139 +158,3 @@ func mustParseURL(rawURL string) *url.URL {
}
}
return
u
return
u
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type
TestProfileExpectation
struct
{
Profile
*
Profile
ExpectedJA3
string
// Expected JA3 hash (empty = don't check)
ExpectedJA4
string
// Expected full JA4 (empty = don't check)
JA4CipherHash
string
// Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func
TestAllProfiles
(
t
*
testing
.
T
)
{
if
testing
.
Short
()
{
t
.
Skip
(
"skipping integration test in short mode"
)
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles
:=
[]
TestProfileExpectation
{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile
:
&
Profile
{
Name
:
"linux_x64_node_v22171"
,
EnableGREASE
:
false
,
CipherSuites
:
[]
uint16
{
4866
,
4867
,
4865
,
49199
,
49195
,
49200
,
49196
,
158
,
49191
,
103
,
49192
,
107
,
163
,
159
,
52393
,
52392
,
52394
,
49327
,
49325
,
49315
,
49311
,
49245
,
49249
,
49239
,
49235
,
162
,
49326
,
49324
,
49314
,
49310
,
49244
,
49248
,
49238
,
49234
,
49188
,
106
,
49187
,
64
,
49162
,
49172
,
57
,
56
,
49161
,
49171
,
51
,
50
,
157
,
49313
,
49309
,
49233
,
156
,
49312
,
49308
,
49232
,
61
,
60
,
53
,
47
,
255
},
Curves
:
[]
uint16
{
29
,
23
,
30
,
25
,
24
,
256
,
257
,
258
,
259
,
260
},
PointFormats
:
[]
uint8
{
0
,
1
,
2
},
},
JA4CipherHash
:
"a33745022dd6"
,
// stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile
:
&
Profile
{
Name
:
"macos_arm64_node_v22180"
,
EnableGREASE
:
false
,
CipherSuites
:
[]
uint16
{
4866
,
4867
,
4865
,
49199
,
49195
,
49200
,
49196
,
158
,
49191
,
103
,
49192
,
107
,
163
,
159
,
52393
,
52392
,
52394
,
49327
,
49325
,
49315
,
49311
,
49245
,
49249
,
49239
,
49235
,
162
,
49326
,
49324
,
49314
,
49310
,
49244
,
49248
,
49238
,
49234
,
49188
,
106
,
49187
,
64
,
49162
,
49172
,
57
,
56
,
49161
,
49171
,
51
,
50
,
157
,
49313
,
49309
,
49233
,
156
,
49312
,
49308
,
49232
,
61
,
60
,
53
,
47
,
255
},
Curves
:
[]
uint16
{
29
,
23
,
30
,
25
,
24
,
256
,
257
,
258
,
259
,
260
},
PointFormats
:
[]
uint8
{
0
,
1
,
2
},
},
JA4CipherHash
:
"a33745022dd6"
,
// stable part (same cipher suites)
},
}
for
_
,
tc
:=
range
profiles
{
tc
:=
tc
// capture range variable
t
.
Run
(
tc
.
Profile
.
Name
,
func
(
t
*
testing
.
T
)
{
fp
:=
fetchFingerprint
(
t
,
tc
.
Profile
)
if
fp
==
nil
{
return
// fetchFingerprint already called t.Fatal
}
t
.
Logf
(
"Profile: %s"
,
tc
.
Profile
.
Name
)
t
.
Logf
(
" JA3: %s"
,
fp
.
JA3
)
t
.
Logf
(
" JA3 Hash: %s"
,
fp
.
JA3Hash
)
t
.
Logf
(
" JA4: %s"
,
fp
.
JA4
)
t
.
Logf
(
" PeetPrint: %s"
,
fp
.
PeetPrint
)
t
.
Logf
(
" PeetPrintHash: %s"
,
fp
.
PeetPrintHash
)
// Verify expectations
if
tc
.
ExpectedJA3
!=
""
{
if
fp
.
JA3Hash
==
tc
.
ExpectedJA3
{
t
.
Logf
(
" ✓ JA3 hash matches: %s"
,
tc
.
ExpectedJA3
)
}
else
{
t
.
Errorf
(
" ✗ JA3 hash mismatch: got %s, expected %s"
,
fp
.
JA3Hash
,
tc
.
ExpectedJA3
)
}
}
if
tc
.
ExpectedJA4
!=
""
{
if
fp
.
JA4
==
tc
.
ExpectedJA4
{
t
.
Logf
(
" ✓ JA4 matches: %s"
,
tc
.
ExpectedJA4
)
}
else
{
t
.
Errorf
(
" ✗ JA4 mismatch: got %s, expected %s"
,
fp
.
JA4
,
tc
.
ExpectedJA4
)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if
tc
.
JA4CipherHash
!=
""
{
if
strings
.
Contains
(
fp
.
JA4
,
"_"
+
tc
.
JA4CipherHash
+
"_"
)
{
t
.
Logf
(
" ✓ JA4 cipher hash matches: %s"
,
tc
.
JA4CipherHash
)
}
else
{
t
.
Errorf
(
" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s"
,
fp
.
JA4
,
tc
.
JA4CipherHash
)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func
fetchFingerprint
(
t
*
testing
.
T
,
profile
*
Profile
)
*
TLSInfo
{
t
.
Helper
()
dialer
:=
NewDialer
(
profile
,
nil
)
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://tls.peet.ws/api/all"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to create request: %v"
,
err
)
return
nil
}
req
.
Header
.
Set
(
"User-Agent"
,
"Claude Code/2.0.0 Node.js/20.0.0"
)
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to get fingerprint: %v"
,
err
)
return
nil
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to read response: %v"
,
err
)
return
nil
}
var
fpResp
FingerprintResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
fpResp
);
err
!=
nil
{
t
.
Logf
(
"Response body: %s"
,
string
(
body
))
t
.
Fatalf
(
"failed to parse fingerprint response: %v"
,
err
)
return
nil
}
return
&
fpResp
.
TLS
}
backend/internal/repository/aes_encryptor.go
0 → 100644
View file @
a161fcc8
package
repository
import
(
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// AESEncryptor implements SecretEncryptor using AES-256-GCM
type
AESEncryptor
struct
{
key
[]
byte
}
// NewAESEncryptor creates a new AES encryptor
func
NewAESEncryptor
(
cfg
*
config
.
Config
)
(
service
.
SecretEncryptor
,
error
)
{
key
,
err
:=
hex
.
DecodeString
(
cfg
.
Totp
.
EncryptionKey
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid totp encryption key: %w"
,
err
)
}
if
len
(
key
)
!=
32
{
return
nil
,
fmt
.
Errorf
(
"totp encryption key must be 32 bytes (64 hex chars), got %d bytes"
,
len
(
key
))
}
return
&
AESEncryptor
{
key
:
key
},
nil
}
// Encrypt encrypts plaintext using AES-256-GCM
// Output format: base64(nonce + ciphertext + tag)
func
(
e
*
AESEncryptor
)
Encrypt
(
plaintext
string
)
(
string
,
error
)
{
block
,
err
:=
aes
.
NewCipher
(
e
.
key
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create cipher: %w"
,
err
)
}
gcm
,
err
:=
cipher
.
NewGCM
(
block
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create gcm: %w"
,
err
)
}
// Generate a random nonce
nonce
:=
make
([]
byte
,
gcm
.
NonceSize
())
if
_
,
err
:=
io
.
ReadFull
(
rand
.
Reader
,
nonce
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate nonce: %w"
,
err
)
}
// Encrypt the plaintext
// Seal appends the ciphertext and tag to the nonce
ciphertext
:=
gcm
.
Seal
(
nonce
,
nonce
,
[]
byte
(
plaintext
),
nil
)
// Encode as base64
return
base64
.
StdEncoding
.
EncodeToString
(
ciphertext
),
nil
}
// Decrypt decrypts ciphertext using AES-256-GCM
func
(
e
*
AESEncryptor
)
Decrypt
(
ciphertext
string
)
(
string
,
error
)
{
// Decode from base64
data
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
ciphertext
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decode base64: %w"
,
err
)
}
block
,
err
:=
aes
.
NewCipher
(
e
.
key
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create cipher: %w"
,
err
)
}
gcm
,
err
:=
cipher
.
NewGCM
(
block
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create gcm: %w"
,
err
)
}
nonceSize
:=
gcm
.
NonceSize
()
if
len
(
data
)
<
nonceSize
{
return
""
,
fmt
.
Errorf
(
"ciphertext too short"
)
}
// Extract nonce and ciphertext
nonce
,
ciphertextData
:=
data
[
:
nonceSize
],
data
[
nonceSize
:
]
// Decrypt
plaintext
,
err
:=
gcm
.
Open
(
nil
,
nonce
,
ciphertextData
,
nil
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decrypt: %w"
,
err
)
}
return
string
(
plaintext
),
nil
}
backend/internal/repository/api_key_repo.go
View file @
a161fcc8
...
@@ -396,6 +396,9 @@ func userEntityToService(u *dbent.User) *service.User {
...
@@ -396,6 +396,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance
:
u
.
Balance
,
Balance
:
u
.
Balance
,
Concurrency
:
u
.
Concurrency
,
Concurrency
:
u
.
Concurrency
,
Status
:
u
.
Status
,
Status
:
u
.
Status
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
TotpEnabled
:
u
.
TotpEnabled
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
CreatedAt
:
u
.
CreatedAt
,
CreatedAt
:
u
.
CreatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
}
}
...
...
backend/internal/repository/claude_oauth_service.go
View file @
a161fcc8
...
@@ -36,6 +36,8 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
...
@@ -36,6 +36,8 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
var
orgs
[]
struct
{
var
orgs
[]
struct
{
UUID
string
`json:"uuid"`
UUID
string
`json:"uuid"`
Name
string
`json:"name"`
RavenType
*
string
`json:"raven_type"`
// nil for personal, "team" for team organization
}
}
targetURL
:=
s
.
baseURL
+
"/api/organizations"
targetURL
:=
s
.
baseURL
+
"/api/organizations"
...
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
...
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return
""
,
fmt
.
Errorf
(
"no organizations found"
)
return
""
,
fmt
.
Errorf
(
"no organizations found"
)
}
}
log
.
Printf
(
"[OAuth] Step 1 SUCCESS - Got org UUID: %s"
,
orgs
[
0
]
.
UUID
)
// 如果只有一个组织,直接使用
if
len
(
orgs
)
==
1
{
log
.
Printf
(
"[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s"
,
orgs
[
0
]
.
UUID
,
orgs
[
0
]
.
Name
)
return
orgs
[
0
]
.
UUID
,
nil
}
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for
_
,
org
:=
range
orgs
{
if
org
.
RavenType
!=
nil
&&
*
org
.
RavenType
==
"team"
{
log
.
Printf
(
"[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s"
,
org
.
UUID
,
org
.
Name
,
*
org
.
RavenType
)
return
org
.
UUID
,
nil
}
}
// 如果没有 team 类型的组织,使用第一个
log
.
Printf
(
"[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s"
,
orgs
[
0
]
.
UUID
,
orgs
[
0
]
.
Name
)
return
orgs
[
0
]
.
UUID
,
nil
return
orgs
[
0
]
.
UUID
,
nil
}
}
...
...
backend/internal/repository/email_cache.go
View file @
a161fcc8
...
@@ -9,13 +9,27 @@ import (
...
@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9"
)
)
const
verifyCodeKeyPrefix
=
"verify_code:"
const
(
verifyCodeKeyPrefix
=
"verify_code:"
passwordResetKeyPrefix
=
"password_reset:"
passwordResetSentAtKeyPrefix
=
"password_reset_sent:"
)
// verifyCodeKey generates the Redis key for email verification code.
// verifyCodeKey generates the Redis key for email verification code.
func
verifyCodeKey
(
email
string
)
string
{
func
verifyCodeKey
(
email
string
)
string
{
return
verifyCodeKeyPrefix
+
email
return
verifyCodeKeyPrefix
+
email
}
}
// passwordResetKey generates the Redis key for password reset token.
func
passwordResetKey
(
email
string
)
string
{
return
passwordResetKeyPrefix
+
email
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func
passwordResetSentAtKey
(
email
string
)
string
{
return
passwordResetSentAtKeyPrefix
+
email
}
type
emailCache
struct
{
type
emailCache
struct
{
rdb
*
redis
.
Client
rdb
*
redis
.
Client
}
}
...
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
...
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key
:=
verifyCodeKey
(
email
)
key
:=
verifyCodeKey
(
email
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
}
// Password reset token methods
func
(
c
*
emailCache
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
service
.
PasswordResetTokenData
,
error
)
{
key
:=
passwordResetKey
(
email
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
var
data
service
.
PasswordResetTokenData
if
err
:=
json
.
Unmarshal
([]
byte
(
val
),
&
data
);
err
!=
nil
{
return
nil
,
err
}
return
&
data
,
nil
}
func
(
c
*
emailCache
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
service
.
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
{
key
:=
passwordResetKey
(
email
)
val
,
err
:=
json
.
Marshal
(
data
)
if
err
!=
nil
{
return
err
}
return
c
.
rdb
.
Set
(
ctx
,
key
,
val
,
ttl
)
.
Err
()
}
func
(
c
*
emailCache
)
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
{
key
:=
passwordResetKey
(
email
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// Password reset email cooldown methods
func
(
c
*
emailCache
)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
{
key
:=
passwordResetSentAtKey
(
email
)
exists
,
err
:=
c
.
rdb
.
Exists
(
ctx
,
key
)
.
Result
()
return
err
==
nil
&&
exists
>
0
}
func
(
c
*
emailCache
)
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
{
key
:=
passwordResetSentAtKey
(
email
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
"1"
,
ttl
)
.
Err
()
}
backend/internal/repository/totp_cache.go
0 → 100644
View file @
a161fcc8
package
repository
import
(
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const
(
totpSetupKeyPrefix
=
"totp:setup:"
totpLoginKeyPrefix
=
"totp:login:"
totpAttemptsKeyPrefix
=
"totp:attempts:"
totpAttemptsTTL
=
15
*
time
.
Minute
)
// TotpCache implements service.TotpCache using Redis
type
TotpCache
struct
{
rdb
*
redis
.
Client
}
// NewTotpCache creates a new TOTP cache
func
NewTotpCache
(
rdb
*
redis
.
Client
)
service
.
TotpCache
{
return
&
TotpCache
{
rdb
:
rdb
}
}
// GetSetupSession retrieves a TOTP setup session
func
(
c
*
TotpCache
)
GetSetupSession
(
ctx
context
.
Context
,
userID
int64
)
(
*
service
.
TotpSetupSession
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
data
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Bytes
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
return
nil
,
fmt
.
Errorf
(
"get setup session: %w"
,
err
)
}
var
session
service
.
TotpSetupSession
if
err
:=
json
.
Unmarshal
(
data
,
&
session
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal setup session: %w"
,
err
)
}
return
&
session
,
nil
}
// SetSetupSession stores a TOTP setup session
func
(
c
*
TotpCache
)
SetSetupSession
(
ctx
context
.
Context
,
userID
int64
,
session
*
service
.
TotpSetupSession
,
ttl
time
.
Duration
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
data
,
err
:=
json
.
Marshal
(
session
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal setup session: %w"
,
err
)
}
if
err
:=
c
.
rdb
.
Set
(
ctx
,
key
,
data
,
ttl
)
.
Err
();
err
!=
nil
{
return
fmt
.
Errorf
(
"set setup session: %w"
,
err
)
}
return
nil
}
// DeleteSetupSession deletes a TOTP setup session
func
(
c
*
TotpCache
)
DeleteSetupSession
(
ctx
context
.
Context
,
userID
int64
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// GetLoginSession retrieves a TOTP login session
func
(
c
*
TotpCache
)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
service
.
TotpLoginSession
,
error
)
{
key
:=
totpLoginKeyPrefix
+
tempToken
data
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Bytes
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
return
nil
,
fmt
.
Errorf
(
"get login session: %w"
,
err
)
}
var
session
service
.
TotpLoginSession
if
err
:=
json
.
Unmarshal
(
data
,
&
session
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal login session: %w"
,
err
)
}
return
&
session
,
nil
}
// SetLoginSession stores a TOTP login session
func
(
c
*
TotpCache
)
SetLoginSession
(
ctx
context
.
Context
,
tempToken
string
,
session
*
service
.
TotpLoginSession
,
ttl
time
.
Duration
)
error
{
key
:=
totpLoginKeyPrefix
+
tempToken
data
,
err
:=
json
.
Marshal
(
session
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal login session: %w"
,
err
)
}
if
err
:=
c
.
rdb
.
Set
(
ctx
,
key
,
data
,
ttl
)
.
Err
();
err
!=
nil
{
return
fmt
.
Errorf
(
"set login session: %w"
,
err
)
}
return
nil
}
// DeleteLoginSession deletes a TOTP login session
func
(
c
*
TotpCache
)
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
{
key
:=
totpLoginKeyPrefix
+
tempToken
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// IncrementVerifyAttempts increments the verify attempt counter
func
(
c
*
TotpCache
)
IncrementVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpAttemptsKeyPrefix
,
userID
)
// Use pipeline for atomic increment and set TTL
pipe
:=
c
.
rdb
.
Pipeline
()
incrCmd
:=
pipe
.
Incr
(
ctx
,
key
)
pipe
.
Expire
(
ctx
,
key
,
totpAttemptsTTL
)
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"increment verify attempts: %w"
,
err
)
}
count
,
err
:=
incrCmd
.
Result
()
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"get increment result: %w"
,
err
)
}
return
int
(
count
),
nil
}
// GetVerifyAttempts gets the current verify attempt count
func
(
c
*
TotpCache
)
GetVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpAttemptsKeyPrefix
,
userID
)
count
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
0
,
nil
}
return
0
,
fmt
.
Errorf
(
"get verify attempts: %w"
,
err
)
}
return
count
,
nil
}
// ClearVerifyAttempts clears the verify attempt counter
func
(
c
*
TotpCache
)
ClearVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpAttemptsKeyPrefix
,
userID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment