Commit 3d79773b authored by kyx236's avatar kyx236
Browse files

Merge branch 'main' of https://github.com/james-6-23/sub2api

parents 6aa8cbbf 742e73c9
package openai
import (
"net/url"
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "true" {
t.Fatalf("codex flow mismatch: got=%q want=true", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "" {
t.Fatalf("codex flow should be empty for sora, got=%q", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
package openai
import "strings"
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{
......@@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{
"codex_cli_rs/",
}
// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。
// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。
var CodexOfficialClientUserAgentPrefixes = []string{
"codex_cli_rs/",
"codex_vscode/",
"codex_app/",
"codex_chatgpt_desktop/",
"codex_atlas/",
"codex_exec/",
"codex_sdk_ts/",
"codex ",
}
// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。
// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。
// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。
var CodexOfficialClientOriginatorPrefixes = []string{
"codex_",
"codex ",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
ua := normalizeCodexClientHeader(userAgent)
if ua == "" {
return false
}
return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes)
}
// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。
// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。
func IsCodexOfficialClientRequest(userAgent string) bool {
ua := normalizeCodexClientHeader(userAgent)
if ua == "" {
return false
}
return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes)
}
// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。
func IsCodexOfficialClientOriginator(originator string) bool {
v := normalizeCodexClientHeader(originator)
if v == "" {
return false
}
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
}
func normalizeCodexClientHeader(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}
func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool {
for _, prefix := range prefixes {
normalizedPrefix := normalizeCodexClientHeader(prefix)
if normalizedPrefix == "" {
continue
}
// 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。
if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) {
return true
}
}
......
package openai
import "testing"
func TestIsCodexCLIRequest(t *testing.T) {
tests := []struct {
name string
ua string
want bool
}{
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true},
{name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true},
{name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true},
{name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true},
{name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true},
{name: "非 codex", ua: "curl/8.0.1", want: false},
{name: "空字符串", ua: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexCLIRequest(tt.ua)
if got != tt.want {
t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want)
}
})
}
}
func TestIsCodexOfficialClientRequest(t *testing.T) {
tests := []struct {
name string
ua string
want bool
}{
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true},
{name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true},
{name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true},
{name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", want: true},
{name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true},
{name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true},
{name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true},
{name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true},
{name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true},
{name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true},
{name: "非 codex", ua: "curl/8.0.1", want: false},
{name: "空字符串", ua: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientRequest(tt.ua)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want)
}
})
}
}
func TestIsCodexOfficialClientOriginator(t *testing.T) {
tests := []struct {
name string
originator string
want bool
}{
{name: "codex_cli_rs", originator: "codex_cli_rs", want: true},
{name: "codex_vscode", originator: "codex_vscode", want: true},
{name: "codex_app", originator: "codex_app", want: true},
{name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true},
{name: "codex_atlas", originator: "codex_atlas", want: true},
{name: "codex_exec", originator: "codex_exec", want: true},
{name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true},
{name: "Codex 前缀", originator: "Codex Desktop", want: true},
{name: "空白包裹", originator: " codex_vscode ", want: true},
{name: "非 codex", originator: "my_client", want: false},
{name: "空字符串", originator: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientOriginator(tt.originator)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want)
}
})
}
}
// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连)
//
// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。
// 直接使用 url.Parse 处理代理 URL 是被禁止的。
// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败,
// 而不是在运行时静默回退到直连(产生 IP 关联风险)。
package proxyurl
import (
"fmt"
"net/url"
"strings"
)
// allowedSchemes 代理协议白名单
var allowedSchemes = map[string]bool{
"http": true,
"https": true,
"socks5": true,
"socks5h": true,
}
// Parse 解析并验证代理 URL。
//
// 语义:
// - 空字符串 → ("", nil, nil),表示直连
// - 非空且有效 → (trimmed, *url.URL, nil)
// - 非空但无效 → ("", nil, error),fail-fast 不回退
//
// 验证规则:
// - TrimSpace 后为空视为直连
// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露)
// - Host 为空返回 error(用 Redacted() 脱敏)
// - Scheme 必须为 http/https/socks5/socks5h
// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏)
func Parse(raw string) (trimmed string, parsed *url.URL, err error) {
trimmed = strings.TrimSpace(raw)
if trimmed == "" {
return "", nil, nil
}
parsed, err = url.Parse(trimmed)
if err != nil {
// 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据)
return "", nil, fmt.Errorf("invalid proxy URL: %v", err)
}
if parsed.Host == "" || parsed.Hostname() == "" {
return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted())
}
scheme := strings.ToLower(parsed.Scheme)
if !allowedSchemes[scheme] {
return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme)
}
// 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。
// Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS,
// 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。
if scheme == "socks5" {
parsed.Scheme = "socks5h"
trimmed = parsed.String()
}
return trimmed, parsed, nil
}
package proxyurl
import (
"strings"
"testing"
)
func TestParse_空字符串直连(t *testing.T) {
trimmed, parsed, err := Parse("")
if err != nil {
t.Fatalf("空字符串应直连: %v", err)
}
if trimmed != "" {
t.Errorf("trimmed 应为空: got %q", trimmed)
}
if parsed != nil {
t.Errorf("parsed 应为 nil: got %v", parsed)
}
}
func TestParse_空白字符串直连(t *testing.T) {
trimmed, parsed, err := Parse(" ")
if err != nil {
t.Fatalf("空白字符串应直连: %v", err)
}
if trimmed != "" {
t.Errorf("trimmed 应为空: got %q", trimmed)
}
if parsed != nil {
t.Errorf("parsed 应为 nil: got %v", parsed)
}
}
func TestParse_有效HTTP代理(t *testing.T) {
trimmed, parsed, err := Parse("http://proxy.example.com:8080")
if err != nil {
t.Fatalf("有效 HTTP 代理应成功: %v", err)
}
if trimmed != "http://proxy.example.com:8080" {
t.Errorf("trimmed 不匹配: got %q", trimmed)
}
if parsed == nil {
t.Fatal("parsed 不应为 nil")
}
if parsed.Host != "proxy.example.com:8080" {
t.Errorf("Host 不匹配: got %q", parsed.Host)
}
}
func TestParse_有效HTTPS代理(t *testing.T) {
_, parsed, err := Parse("https://proxy.example.com:443")
if err != nil {
t.Fatalf("有效 HTTPS 代理应成功: %v", err)
}
if parsed.Scheme != "https" {
t.Errorf("Scheme 不匹配: got %q", parsed.Scheme)
}
}
func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) {
trimmed, parsed, err := Parse("socks5://127.0.0.1:1080")
if err != nil {
t.Fatalf("有效 SOCKS5 代理应成功: %v", err)
}
// socks5 自动升级为 socks5h,确保 DNS 由代理端解析
if trimmed != "socks5h://127.0.0.1:1080" {
t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed)
}
if parsed.Scheme != "socks5h" {
t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme)
}
}
func TestParse_无效URL(t *testing.T) {
_, _, err := Parse("://invalid")
if err == nil {
t.Fatal("无效 URL 应返回错误")
}
if !strings.Contains(err.Error(), "invalid proxy URL") {
t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error())
}
}
func TestParse_缺少Host(t *testing.T) {
_, _, err := Parse("http://")
if err == nil {
t.Fatal("缺少 host 应返回错误")
}
if !strings.Contains(err.Error(), "missing host") {
t.Errorf("错误信息应包含 'missing host': got %s", err.Error())
}
}
func TestParse_不支持的Scheme(t *testing.T) {
_, _, err := Parse("ftp://proxy.example.com:21")
if err == nil {
t.Fatal("不支持的 scheme 应返回错误")
}
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error())
}
}
func TestParse_含密码URL脱敏(t *testing.T) {
// 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h
trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080")
if err != nil {
t.Fatalf("含密码的有效 URL 应成功: %v", err)
}
if trimmed == "" || parsed == nil {
t.Fatal("应返回非空结果")
}
if parsed.Scheme != "socks5h" {
t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme)
}
if !strings.HasPrefix(trimmed, "socks5h://") {
t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed)
}
if parsed.User == nil {
t.Error("升级后应保留 UserInfo")
}
// 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径)
_, _, err = Parse("http://user:secret_password@:0/")
if err == nil {
t.Fatal("缺少 host 应返回错误")
}
if strings.Contains(err.Error(), "secret_password") {
t.Error("错误信息不应包含明文密码")
}
if !strings.Contains(err.Error(), "missing host") {
t.Errorf("错误信息应包含 'missing host': got %s", err.Error())
}
}
func TestParse_带空白的有效URL(t *testing.T) {
trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ")
if err != nil {
t.Fatalf("带空白的有效 URL 应成功: %v", err)
}
if trimmed != "http://proxy.example.com:8080" {
t.Errorf("trimmed 应去除空白: got %q", trimmed)
}
if parsed == nil {
t.Fatal("parsed 不应为 nil")
}
}
func TestParse_Scheme大小写不敏感(t *testing.T) {
// 大写 SOCKS5 应被接受并升级为 socks5h
trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080")
if err != nil {
t.Fatalf("大写 SOCKS5 应被接受: %v", err)
}
if parsed.Scheme != "socks5h" {
t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme)
}
if !strings.HasPrefix(trimmed, "socks5h://") {
t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed)
}
// 大写 HTTP 应被接受(不变)
_, _, err = Parse("HTTP://proxy.example.com:8080")
if err != nil {
t.Fatalf("大写 HTTP 应被接受: %v", err)
}
}
func TestParse_带认证的有效代理(t *testing.T) {
trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080")
if err != nil {
t.Fatalf("带认证的代理 URL 应成功: %v", err)
}
if parsed.User == nil {
t.Error("应保留 UserInfo")
}
if trimmed != "http://user:pass@proxy.example.com:8080" {
t.Errorf("trimmed 不匹配: got %q", trimmed)
}
}
func TestParse_IPv6地址(t *testing.T) {
trimmed, parsed, err := Parse("http://[::1]:8080")
if err != nil {
t.Fatalf("IPv6 代理 URL 应成功: %v", err)
}
if parsed.Hostname() != "::1" {
t.Errorf("Hostname 不匹配: got %q", parsed.Hostname())
}
if trimmed != "http://[::1]:8080" {
t.Errorf("trimmed 不匹配: got %q", trimmed)
}
}
func TestParse_SOCKS5H保持不变(t *testing.T) {
trimmed, parsed, err := Parse("socks5h://proxy.local:1080")
if err != nil {
t.Fatalf("有效 SOCKS5H 代理应成功: %v", err)
}
// socks5h 不需要升级,应保持原样
if trimmed != "socks5h://proxy.local:1080" {
t.Errorf("trimmed 不应变化: got %q", trimmed)
}
if parsed.Scheme != "socks5h" {
t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme)
}
}
func TestParse_无Scheme裸地址(t *testing.T) {
// 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空
_, _, err := Parse("proxy.example.com:8080")
if err == nil {
t.Fatal("无 scheme 的裸地址应返回错误")
}
}
......@@ -2,7 +2,11 @@
//
// 支持的代理协议:
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS)
// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS)
// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐)
//
// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://,
// 确保 DNS 也由代理端解析,防止 DNS 泄漏。
package proxyutil
import (
......@@ -20,7 +24,8 @@ import (
//
// 支持的协议:
// - http/https: 设置 transport.Proxy
// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS)
// - socks5: 设置 transport.DialContext(客户端本地解析 DNS)
// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐)
//
// 参数:
// - transport: 需要配置的 http.Transport
......
......@@ -7,6 +7,7 @@ import (
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/gin-gonic/gin"
)
......@@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool {
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
......
......@@ -14,6 +14,44 @@ import (
"github.com/stretchr/testify/require"
)
// ---------- 辅助函数 ----------
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
t.Helper()
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
return got
}
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型
var raw struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
var pd PaginatedData
require.NoError(t, json.Unmarshal(raw.Data, &pd))
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
}
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
return w, c
}
// ---------- 现有测试 ----------
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
......@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
})
}
}
// ---------- 新增测试 ----------
func TestSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
wantBody Response
}{
{
name: "返回字符串数据",
data: "hello",
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
},
{
name: "返回nil数据",
data: nil,
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
{
name: "返回map数据",
data: map[string]string{"key": "value"},
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Success(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
if tt.data == nil {
require.Nil(t, got.Data)
} else {
require.NotNil(t, got.Data)
}
})
}
}
func TestCreated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
}{
{
name: "创建成功_返回数据",
data: map[string]int{"id": 42},
wantCode: http.StatusCreated,
},
{
name: "创建成功_nil数据",
data: nil,
wantCode: http.StatusCreated,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Created(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
})
}
}
func TestError(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
}{
{
name: "400错误",
statusCode: http.StatusBadRequest,
message: "bad request",
},
{
name: "500错误",
statusCode: http.StatusInternalServerError,
message: "internal error",
},
{
name: "自定义状态码",
statusCode: 418,
message: "I'm a teapot",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Error(c, tt.statusCode, tt.message)
require.Equal(t, tt.statusCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, tt.statusCode, got.Code)
require.Equal(t, tt.message, got.Message)
require.Empty(t, got.Reason)
require.Nil(t, got.Metadata)
require.Nil(t, got.Data)
})
}
}
func TestBadRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
BadRequest(c, "参数无效")
require.Equal(t, http.StatusBadRequest, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusBadRequest, got.Code)
require.Equal(t, "参数无效", got.Message)
}
func TestUnauthorized(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Unauthorized(c, "未登录")
require.Equal(t, http.StatusUnauthorized, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusUnauthorized, got.Code)
require.Equal(t, "未登录", got.Message)
}
func TestForbidden(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Forbidden(c, "无权限")
require.Equal(t, http.StatusForbidden, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusForbidden, got.Code)
require.Equal(t, "无权限", got.Message)
}
func TestNotFound(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
NotFound(c, "资源不存在")
require.Equal(t, http.StatusNotFound, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusNotFound, got.Code)
require.Equal(t, "资源不存在", got.Message)
}
func TestInternalError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
InternalError(c, "服务器内部错误")
require.Equal(t, http.StatusInternalServerError, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusInternalServerError, got.Code)
require.Equal(t, "服务器内部错误", got.Message)
}
func TestPaginated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
total int64
page int
pageSize int
wantPages int
wantTotal int64
wantPage int
wantPageSize int
}{
{
name: "标准分页_多页",
items: []string{"a", "b"},
total: 25,
page: 1,
pageSize: 10,
wantPages: 3,
wantTotal: 25,
wantPage: 1,
wantPageSize: 10,
},
{
name: "总数刚好整除",
items: []string{"a"},
total: 20,
page: 2,
pageSize: 10,
wantPages: 2,
wantTotal: 20,
wantPage: 2,
wantPageSize: 10,
},
{
name: "总数为0_pages至少为1",
items: []string{},
total: 0,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 0,
wantPage: 1,
wantPageSize: 10,
},
{
name: "单页数据",
items: []int{1, 2, 3},
total: 3,
page: 1,
pageSize: 20,
wantPages: 1,
wantTotal: 3,
wantPage: 1,
wantPageSize: 20,
},
{
name: "总数为1",
items: []string{"only"},
total: 1,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 1,
wantPage: 1,
wantPageSize: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestPaginatedWithResult(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
pagination *PaginationResult
wantTotal int64
wantPage int
wantPageSize int
wantPages int
}{
{
name: "正常分页结果",
items: []string{"a", "b"},
pagination: &PaginationResult{
Total: 50,
Page: 3,
PageSize: 10,
Pages: 5,
},
wantTotal: 50,
wantPage: 3,
wantPageSize: 10,
wantPages: 5,
},
{
name: "pagination为nil_使用默认值",
items: []string{},
pagination: nil,
wantTotal: 0,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
{
name: "单页结果",
items: []int{1},
pagination: &PaginationResult{
Total: 1,
Page: 1,
PageSize: 20,
Pages: 1,
},
wantTotal: 1,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
PaginatedWithResult(c, tt.items, tt.pagination)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestParsePagination(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
query string
wantPage int
wantPageSize int
}{
{
name: "无参数_使用默认值",
query: "",
wantPage: 1,
wantPageSize: 20,
},
{
name: "仅指定page",
query: "page=3",
wantPage: 3,
wantPageSize: 20,
},
{
name: "仅指定page_size",
query: "page_size=50",
wantPage: 1,
wantPageSize: 50,
},
{
name: "同时指定page和page_size",
query: "page=2&page_size=30",
wantPage: 2,
wantPageSize: 30,
},
{
name: "使用limit代替page_size",
query: "limit=15",
wantPage: 1,
wantPageSize: 15,
},
{
name: "page_size优先于limit",
query: "page_size=25&limit=50",
wantPage: 1,
wantPageSize: 25,
},
{
name: "page为0_使用默认值",
query: "page=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size超过1000_使用默认值",
query: "page_size=1001",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size恰好1000_有效",
query: "page_size=1000",
wantPage: 1,
wantPageSize: 1000,
},
{
name: "page为非数字_使用默认值",
query: "page=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为非数字_使用默认值",
query: "page_size=xyz",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为非数字_使用默认值",
query: "limit=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为0_使用默认值",
query: "page_size=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为0_使用默认值",
query: "limit=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "大页码",
query: "page=999&page_size=100",
wantPage: 999,
wantPageSize: 100,
},
{
name: "page_size为1_最小有效值",
query: "page_size=1",
wantPage: 1,
wantPageSize: 1,
},
{
name: "混合数字和字母的page",
query: "page=12a",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit超过1000_使用默认值",
query: "limit=2000",
wantPage: 1,
wantPageSize: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, c := newContextWithQuery(tt.query)
page, pageSize := ParsePagination(c)
require.Equal(t, tt.wantPage, page, "page 不符合预期")
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
})
}
}
func Test_parseInt(t *testing.T) {
tests := []struct {
name string
input string
wantVal int
wantErr bool
}{
{
name: "正常数字",
input: "123",
wantVal: 123,
wantErr: false,
},
{
name: "零",
input: "0",
wantVal: 0,
wantErr: false,
},
{
name: "单个数字",
input: "5",
wantVal: 5,
wantErr: false,
},
{
name: "大数字",
input: "99999",
wantVal: 99999,
wantErr: false,
},
{
name: "包含字母_返回0",
input: "abc",
wantVal: 0,
wantErr: false,
},
{
name: "数字开头接字母_返回0",
input: "12a",
wantVal: 0,
wantErr: false,
},
{
name: "包含负号_返回0",
input: "-1",
wantVal: 0,
wantErr: false,
},
{
name: "包含小数点_返回0",
input: "1.5",
wantVal: 0,
wantErr: false,
},
{
name: "包含空格_返回0",
input: "1 2",
wantVal: 0,
wantErr: false,
},
{
name: "空字符串",
input: "",
wantVal: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := parseInt(tt.input)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantVal, val)
})
}
}
......@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions),
"compression_methods", spec.CompressionMethods,
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
"tls_vers_max", spec.TLSVersMax,
"tls_vers_min", spec.TLSVersMin)
if d.profile != nil {
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
......@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
return nil, fmt.Errorf("apply TLS preset: %w", err)
}
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
......@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_socks5_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
......@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
......@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
// Log successful handshake details
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
......
......@@ -30,7 +30,8 @@ func skipIfExternalServiceUnavailable(t *testing.T, err error) {
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable") ||
strings.Contains(errStr, "timeout") {
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "deadline exceeded") {
t.Skipf("skipping test: external service unavailable: %v", err)
}
t.Fatalf("failed to get fingerprint: %v", err)
......
//go:build unit
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Unit tests for TLS fingerprint dialer.
......@@ -9,26 +11,161 @@
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
skipNetworkTest(t)
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
skipNetworkTest(t)
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
func skipNetworkTest(t *testing.T) {
if testing.Short() {
t.Skip("跳过网络测试(short 模式)")
}
if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" {
t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
......@@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL {
}
return u
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
skipNetworkTest(t)
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
return nil
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}
package tlsfingerprint
// FingerprintResponse represents the response from tls.peet.ws/api/all.
// 共享测试类型,供 unit 和 integration 测试文件使用。
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
}
......@@ -78,6 +78,16 @@ type ModelStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GroupStat represents usage statistics for a single group
type GroupStat struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
......@@ -139,10 +149,13 @@ type UsageLogFilters struct {
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
StartTime *time.Time
EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
ExactTotal bool
}
// UsageStats represents usage statistics
......
......@@ -15,7 +15,6 @@ import (
"database/sql"
"encoding/json"
"errors"
"log"
"strconv"
"time"
......@@ -25,6 +24,7 @@ import (
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
......@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
until *time.Time
reason string
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
......@@ -127,7 +122,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
}
return nil
}
......@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID)
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
......@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outByID[entAcc.ID] = out
}
......@@ -388,7 +374,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
account.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
......@@ -429,7 +415,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -541,7 +527,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
},
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -576,7 +562,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
}
payload := map[string]any{"last_used": lastUsedPayload}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
}
return nil
}
......@@ -591,7 +577,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
......@@ -611,11 +597,48 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
return
}
uniqueIDs := make([]int64, 0, len(accountIDs))
seen := make(map[int64]struct{}, len(accountIDs))
for _, id := range accountIDs {
if id <= 0 {
continue
}
if _, exists := seen[id]; exists {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return
}
accounts, err := r.GetByIDs(ctx, uniqueIDs)
if err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err)
return
}
for _, account := range accounts {
if account == nil {
continue
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err)
}
}
}
......@@ -649,7 +672,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
......@@ -666,7 +689,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
......@@ -739,7 +762,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
}
return nil
}
......@@ -824,6 +847,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
......@@ -847,7 +915,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -894,7 +962,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -908,7 +976,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -927,7 +995,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
......@@ -946,7 +1014,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -962,7 +1030,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -986,7 +1054,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -1010,7 +1078,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -1032,7 +1100,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
// 触发调度器缓存更新(仅当窗口时间有变化时)
if start != nil || end != nil {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
}
}
return nil
......@@ -1047,7 +1115,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
......@@ -1075,7 +1143,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
}
if rows > 0 {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
}
}
return rows, nil
......@@ -1111,7 +1179,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -1205,7 +1273,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if rows > 0 {
payload := map[string]any{"account_ids": ids}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
......@@ -1215,9 +1283,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
r.syncSchedulerAccountSnapshots(ctx, ids)
}
}
return rows, nil
......@@ -1309,10 +1375,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil {
return nil, err
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
......@@ -1338,10 +1400,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[acc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outAccounts = append(outAccounts, *out)
}
......@@ -1366,48 +1424,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
)
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
return out, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`, pq.Array(accountIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
var until sql.NullTime
var reason sql.NullString
if err := rows.Scan(&id, &until, &reason); err != nil {
return nil, err
}
var untilPtr *time.Time
if until.Valid {
tmp := until.Time
untilPtr = &tmp
}
if reason.Valid {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
} else {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {
......@@ -1540,6 +1556,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
TempUnschedulableUntil: m.TempUnschedulableUntil,
TempUnschedulableReason: derefString(m.TempUnschedulableReason),
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
......@@ -1578,3 +1596,64 @@ func joinClauses(clauses []string, sep string) string {
func itoa(v int) string {
return strconv.Itoa(v)
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
dbaccount.DeletedAtIsNil(),
func(s *entsql.Selector) {
path := sqljson.Path(key)
switch v := value.(type) {
case string:
preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)}
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path))
}
if len(preds) == 1 {
s.Where(preds[0])
} else {
s.Where(entsql.Or(preds...))
}
case int:
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path),
))
case int64:
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path),
))
case json.Number:
if parsed, err := v.Int64(); err == nil {
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path),
sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path),
))
} else {
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path))
}
default:
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path))
}
},
).
All(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
return r.accountsToService(ctx, accounts)
}
......@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s.Require().Nil(got.OverloadUntil)
}
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
reason := `{"rule":"429","matched_keyword":"too many requests"}`
s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().NotNil(gotByID.TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByID.TempUnschedulableReason)
gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
s.Require().NoError(err)
s.Require().Len(gotByIDs, 2)
s.Require().Equal(acc2.ID, gotByIDs[0].ID)
s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
s.Require().Equal(acc1.ID, gotByIDs[1].ID)
s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().Nil(cleared.TempUnschedulableUntil)
s.Require().Equal("", cleared.TempUnschedulableReason)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
......
......@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewAPIKeyRepository(entClient)
apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
......
......@@ -2,6 +2,7 @@ package repository
import (
"context"
"database/sql"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
......@@ -16,10 +17,15 @@ import (
type apiKeyRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client}
func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
return newAPIKeyRepositoryWithSQL(client, sqlDB)
}
func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository {
return &apiKeyRepository{client: client, sql: sqlq}
}
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
......@@ -34,9 +40,13 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
SetNillableLastUsedAt(key.LastUsedAt).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetNillableExpiresAt(key.ExpiresAt)
SetNillableExpiresAt(key.ExpiresAt).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
......@@ -48,6 +58,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
created, err := builder.Save(ctx)
if err == nil {
key.ID = created.ID
key.LastUsedAt = created.LastUsedAt
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
......@@ -116,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldQuota,
apikey.FieldQuotaUsed,
apikey.FieldExpiresAt,
apikey.FieldRateLimit5h,
apikey.FieldRateLimit1d,
apikey.FieldRateLimit7d,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
......@@ -140,6 +154,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldSoraImagePrice360,
group.FieldSoraImagePrice540,
group.FieldSoraVideoPricePerRequest,
group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
......@@ -165,13 +183,20 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
client := clientFromContext(ctx, r.client)
now := time.Now()
builder := r.client.APIKey.Update().
builder := client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d).
SetUsage5h(key.Usage5h).
SetUsage1d(key.Usage1d).
SetUsage7d(key.Usage7d).
SetUpdatedAt(now)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
......@@ -186,6 +211,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearExpiresAt()
}
// Rate limit window start times
if key.Window5hStart != nil {
builder.SetWindow5hStart(*key.Window5hStart)
} else {
builder.ClearWindow5hStart()
}
if key.Window1dStart != nil {
builder.SetWindow1dStart(*key.Window1dStart)
} else {
builder.ClearWindow1dStart()
}
if key.Window7dStart != nil {
builder.SetWindow7dStart(*key.Window7dStart)
} else {
builder.ClearWindow7dStart()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
......@@ -239,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
// Apply filters
if filters.Search != "" {
q = q.Where(apikey.Or(
apikey.NameContainsFold(filters.Search),
apikey.KeyContainsFold(filters.Search),
))
}
if filters.Status != "" {
q = q.Where(apikey.StatusEQ(filters.Status))
}
if filters.GroupID != nil {
if *filters.GroupID == 0 {
q = q.Where(apikey.GroupIDIsNil())
} else {
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
}
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
......@@ -375,36 +435,92 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return keys, nil
}
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldQuotaUsed).
Only(ctx)
updated, err := r.client.APIKey.UpdateOneID(id).
Where(apikey.DeletedAtIsNil()).
AddQuotaUsed(amount).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
return updated.QuotaUsed, nil
}
newValue := m.QuotaUsed + amount
// Update with new value
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetQuotaUsed(newValue).
SetLastUsedAt(usedAt).
SetUpdatedAt(usedAt).
Save(ctx)
if err != nil {
return 0, err
return err
}
if affected == 0 {
return 0, service.ErrAPIKeyNotFound
return service.ErrAPIKeyNotFound
}
return nil
}
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
// window start times via COALESCE if not already set.
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = usage_5h + $1,
usage_1d = usage_1d + $1,
usage_7d = usage_7d + $1,
window_5h_start = COALESCE(window_5h_start, NOW()),
window_1d_start = COALESCE(window_1d_start, NOW()),
window_7d_start = COALESCE(window_7d_start, NOW()),
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL`,
cost, id)
return err
}
// ResetRateLimitWindows resets expired rate limit windows atomically.
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
return err
}
return newValue, nil
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
FROM api_keys
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
if !rows.Next() {
return nil, service.ErrAPIKeyNotFound
}
data := &service.APIKeyRateLimitData{}
if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil {
return nil, err
}
return data, rows.Err()
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
......@@ -419,12 +535,22 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
Quota: m.Quota,
QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt,
RateLimit5h: m.RateLimit5h,
RateLimit1d: m.RateLimit1d,
RateLimit7d: m.RateLimit7d,
Usage5h: m.Usage5h,
Usage1d: m.Usage1d,
Usage7d: m.Usage7d,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
......@@ -449,6 +575,8 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
......@@ -477,6 +605,11 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
......
......@@ -4,11 +4,14 @@ package repository
import (
"context"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
......@@ -23,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
s.repo = newAPIKeyRepositoryWithSQL(s.client, tx)
}
func TestAPIKeyRepoSuite(t *testing.T) {
......@@ -155,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
......@@ -167,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
......@@ -311,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
......@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}
// --- IncrementQuotaUsed ---
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
user := s.mustCreateUser("incr-basic@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
s.Require().NoError(err, "IncrementQuotaUsed")
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
s.Require().NoError(err, "IncrementQuotaUsed second")
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
user := s.mustCreateUser("incr-deleted@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
client := testEntClient(t)
repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository)
ctx := context.Background()
// 创建测试用户和 API Key
u, err := client.User.Create().
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
SetPasswordHash("hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(ctx)
require.NoError(t, err, "create user")
k := &service.APIKey{
UserID: u.ID,
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
Name: "Concurrent",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, k), "create api key")
t.Cleanup(func() {
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
})
// 10 个 goroutine 各递增 1.0,总计应为 10.0
const goroutines = 10
const increment = 1.0
var wg sync.WaitGroup
errs := make([]error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
}(i)
}
wg.Wait()
for i, e := range errs {
require.NoError(t, e, "goroutine %d failed", i)
}
// 验证最终结果
got, err := repo.GetByID(ctx, k.ID)
require.NoError(t, err, "GetByID")
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
}
package repository
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return &apiKeyRepository{client: client}, client
}
func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User {
t.Helper()
u, err := client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
return userEntityToService(u)
}
func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com")
lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second)
key := &service.APIKey{
UserID: user.ID,
Key: "sk-create-last-used",
Name: "CreateWithLastUsed",
Status: service.StatusActive,
LastUsedAt: &lastUsed,
}
require.NoError(t, repo.Create(ctx, key))
require.NotNil(t, key.LastUsedAt)
require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second)
got, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.NotNil(t, got.LastUsedAt)
require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second)
}
func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used",
Name: "UpdateLastUsed",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
before, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, before.LastUsedAt)
target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second)
require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target))
after, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.NotNil(t, after.LastUsedAt)
require.WithinDuration(t, target, *after.LastUsedAt, time.Second)
require.WithinDuration(t, target, after.UpdatedAt, time.Second)
}
func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used-deleted",
Name: "UpdateLastUsedDeleted",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
require.NoError(t, repo.Delete(ctx, key.ID))
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
require.ErrorIs(t, err, service.ErrAPIKeyNotFound)
}
func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used-db-error",
Name: "UpdateLastUsedDBError",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
require.NoError(t, client.Close())
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
require.Error(t, err)
}
func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com")
first := &service.APIKey{
UserID: user.ID,
Key: "sk-duplicate",
Name: "first",
Status: service.StatusActive,
}
second := &service.APIKey{
UserID: user.ID,
Key: "sk-duplicate",
Name: "second",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, first))
err := repo.Create(ctx, second)
require.ErrorIs(t, err, service.ErrAPIKeyExists)
}
......@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
......@@ -15,9 +16,22 @@ import (
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingRateLimitKeyPrefix = "apikey:rate:"
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
)
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
func jitteredTTL() time.Duration {
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。
if billingCacheJitter <= 0 {
return billingCacheTTL
}
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
return billingCacheTTL - jitter
}
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
......@@ -37,6 +51,20 @@ const (
subFieldVersion = "version"
)
// billingRateLimitKey generates the Redis key for API key rate limit cache.
func billingRateLimitKey(keyID int64) string {
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
}
const (
rateLimitFieldUsage5h = "usage_5h"
rateLimitFieldUsage1d = "usage_1d"
rateLimitFieldUsage7d = "usage_7d"
rateLimitFieldWindow5h = "window_5h"
rateLimitFieldWindow1d = "window_1d"
rateLimitFieldWindow7d = "window_7d"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
......@@ -61,6 +89,21 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
updateRateLimitUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
......@@ -82,14 +125,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
return err
}
return nil
}
......@@ -163,16 +207,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
pipe.Expire(ctx, key, jitteredTTL())
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
}
......@@ -181,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID,
key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
key := billingRateLimitKey(keyID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
data := &service.APIKeyRateLimitCacheData{}
if v, ok := result[rateLimitFieldUsage5h]; ok {
data.Usage5h, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage1d]; ok {
data.Usage1d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage7d]; ok {
data.Usage7d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldWindow5h]; ok {
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow1d]; ok {
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow7d]; ok {
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
}
return data, nil
}
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
if data == nil {
return nil
}
key := billingRateLimitKey(keyID)
fields := map[string]any{
rateLimitFieldUsage5h: data.Usage5h,
rateLimitFieldUsage1d: data.Usage1d,
rateLimitFieldUsage7d: data.Usage7d,
rateLimitFieldWindow5h: data.Window5h,
rateLimitFieldWindow1d: data.Window1d,
rateLimitFieldWindow7d: data.Window7d,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, rateLimitCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
key := billingRateLimitKey(keyID)
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
return err
}
return nil
}
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
key := billingRateLimitKey(keyID)
return c.rdb.Del(ctx, key).Err()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment