Commit 5c76b9e4 authored by erio's avatar erio
Browse files

fix: prevent sessionHash collision for different users with same messages

Mix SessionContext (ClientIP, UserAgent, APIKeyID) into
GenerateSessionHash 3rd-level fallback to differentiate requests
from different users sending identical content.

Also switch hashContent from SHA256-truncated to XXHash64 for
better performance, and optimize Trie Lua script to match from
longest prefix first.
parent 0b8fea4c
......@@ -203,6 +203,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
......@@ -962,6 +967,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号
......
......@@ -233,6 +233,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
if parsedReq != nil {
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
}
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
......
......@@ -19,25 +19,34 @@ const (
// ARGV[2] = TTL seconds (用于刷新)
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
// 从最长前缀(完整 chain)开始逐步缩短,第一次命中即返回
geminiTrieFindScript = `
local chain = ARGV[1]
local ttl = tonumber(ARGV[2])
local lastMatch = nil
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
local val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
lastMatch = val
end
-- 先尝试完整 chain(最常见场景:同一对话的下一轮请求)
local val = redis.call('HGET', KEYS[1], chain)
if val and val ~= "" then
redis.call('EXPIRE', KEYS[1], ttl)
return val
end
if lastMatch then
redis.call('EXPIRE', KEYS[1], ttl)
-- 从最长前缀开始逐步缩短(去掉最后一个 "-xxx" 段)
local path = chain
while true do
local i = string.find(path, "-[^-]*$")
if not i or i <= 1 then
break
end
path = string.sub(path, 1, i - 1)
val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
redis.call('EXPIRE', KEYS[1], ttl)
return val
end
end
return lastMatch
return nil
`
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
......
......@@ -9,6 +9,15 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入,
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
type SessionContext struct {
ClientIP string
UserAgent string
APIKeyID int64
}
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
......@@ -22,15 +31,16 @@ import (
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
MaxTokens int // max_tokens 值(用于探测请求拦截)
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
......
......@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
......@@ -17,6 +16,7 @@ import (
"os"
"regexp"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
......@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
......@@ -490,8 +491,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent)
}
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var combined strings.Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
}
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
......@@ -649,8 +659,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
h := xxhash.Sum64String(content)
return strconv.FormatUint(h, 36)
}
// replaceModelInBody 替换请求体中的model字段
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment