Commit 5c76b9e4 authored by erio's avatar erio
Browse files

fix: prevent sessionHash collision for different users with same messages

Mix SessionContext (ClientIP, UserAgent, APIKeyID) into
GenerateSessionHash 3rd-level fallback to differentiate requests
from different users sending identical content.

Also switch hashContent from SHA256-truncated to XXHash64 for
better performance, and optimize Trie Lua script to match from
longest prefix first.
parent 0b8fea4c
...@@ -203,6 +203,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -203,6 +203,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 计算粘性会话hash // 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
...@@ -962,6 +967,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -962,6 +967,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
// 计算粘性会话 hash // 计算粘性会话 hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号 // 选择支持该模型的账号
......
...@@ -233,6 +233,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -233,6 +233,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if sessionHash == "" { if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body) parsedReq, _ := service.ParseGatewayRequest(body)
if parsedReq != nil {
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
}
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq) sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
} }
sessionKey := sessionHash sessionKey := sessionHash
......
...@@ -19,25 +19,34 @@ const ( ...@@ -19,25 +19,34 @@ const (
// ARGV[2] = TTL seconds (用于刷新) // ARGV[2] = TTL seconds (用于刷新)
// 返回: 最长匹配的 value (uuid:accountID) 或 nil // 返回: 最长匹配的 value (uuid:accountID) 或 nil
// 查找成功时自动刷新 TTL,防止活跃会话意外过期 // 查找成功时自动刷新 TTL,防止活跃会话意外过期
// 从最长前缀(完整 chain)开始逐步缩短,第一次命中即返回
geminiTrieFindScript = ` geminiTrieFindScript = `
local chain = ARGV[1] local chain = ARGV[1]
local ttl = tonumber(ARGV[2]) local ttl = tonumber(ARGV[2])
local lastMatch = nil
local path = ""
for part in string.gmatch(chain, "[^-]+") do -- 先尝试完整 chain(最常见场景:同一对话的下一轮请求)
path = path == "" and part or path .. "-" .. part local val = redis.call('HGET', KEYS[1], chain)
local val = redis.call('HGET', KEYS[1], path) if val and val ~= "" then
if val and val ~= "" then redis.call('EXPIRE', KEYS[1], ttl)
lastMatch = val return val
end
end end
if lastMatch then -- 从最长前缀开始逐步缩短(去掉最后一个 "-xxx" 段)
local path = chain
while true do
local i = string.find(path, "-[^-]*$")
if not i or i <= 1 then
break
end
path = string.sub(path, 1, i - 1)
val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
redis.call('EXPIRE', KEYS[1], ttl) redis.call('EXPIRE', KEYS[1], ttl)
return val
end
end end
return lastMatch return nil
` `
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本 // geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
......
...@@ -9,6 +9,15 @@ import ( ...@@ -9,6 +9,15 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
) )
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入,
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
type SessionContext struct {
ClientIP string
UserAgent string
APIKeyID int64
}
// ParsedRequest 保存网关请求的预解析结果 // ParsedRequest 保存网关请求的预解析结果
// //
// 性能优化说明: // 性能优化说明:
...@@ -31,6 +40,7 @@ type ParsedRequest struct { ...@@ -31,6 +40,7 @@ type ParsedRequest struct {
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
MaxTokens int // max_tokens 值(用于探测请求拦截) MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
} }
// ParseGatewayRequest 解析网关请求体并返回结构化结果 // ParseGatewayRequest 解析网关请求体并返回结构化结果
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -17,6 +16,7 @@ import ( ...@@ -17,6 +16,7 @@ import (
"os" "os"
"regexp" "regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
...@@ -490,8 +491,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -490,8 +491,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent) return s.hashContent(cacheableContent)
} }
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串 // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var combined strings.Builder var combined strings.Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
}
if parsed.System != nil { if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System) systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" { if systemText != "" {
...@@ -649,8 +659,8 @@ func (s *GatewayService) extractTextFromContent(content any) string { ...@@ -649,8 +659,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
} }
func (s *GatewayService) hashContent(content string) string { func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content)) h := xxhash.Sum64String(content)
return hex.EncodeToString(hash[:16]) // 32字符 return strconv.FormatUint(h, 36)
} }
// replaceModelInBody 替换请求体中的model字段 // replaceModelInBody 替换请求体中的model字段
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment