Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3d79773b
Commit
3d79773b
authored
Mar 04, 2026
by
kyx236
Browse files
Merge branch 'main' of
https://github.com/james-6-23/sub2api
parents
6aa8cbbf
742e73c9
Changes
253
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
253 of 253+
files are displayed.
Plain diff
Email patch
backend/internal/pkg/openai/oauth_test.go
0 → 100644
View file @
3d79773b
package
openai
import
(
"net/url"
"sync"
"testing"
"time"
)
func
TestSessionStore_Stop_Idempotent
(
t
*
testing
.
T
)
{
store
:=
NewSessionStore
()
store
.
Stop
()
store
.
Stop
()
select
{
case
<-
store
.
stopCh
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"stopCh 未关闭"
)
}
}
func
TestSessionStore_Stop_Concurrent
(
t
*
testing
.
T
)
{
store
:=
NewSessionStore
()
var
wg
sync
.
WaitGroup
for
range
50
{
wg
.
Add
(
1
)
go
func
()
{
defer
wg
.
Done
()
store
.
Stop
()
}()
}
wg
.
Wait
()
select
{
case
<-
store
.
stopCh
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"stopCh 未关闭"
)
}
}
func
TestBuildAuthorizationURLForPlatform_OpenAI
(
t
*
testing
.
T
)
{
authURL
:=
BuildAuthorizationURLForPlatform
(
"state-1"
,
"challenge-1"
,
DefaultRedirectURI
,
OAuthPlatformOpenAI
)
parsed
,
err
:=
url
.
Parse
(
authURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"Parse URL failed: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
got
:=
q
.
Get
(
"client_id"
);
got
!=
ClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q"
,
got
,
ClientID
)
}
if
got
:=
q
.
Get
(
"codex_cli_simplified_flow"
);
got
!=
"true"
{
t
.
Fatalf
(
"codex flow mismatch: got=%q want=true"
,
got
)
}
if
got
:=
q
.
Get
(
"id_token_add_organizations"
);
got
!=
"true"
{
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func
TestBuildAuthorizationURLForPlatform_Sora
(
t
*
testing
.
T
)
{
authURL
:=
BuildAuthorizationURLForPlatform
(
"state-2"
,
"challenge-2"
,
DefaultRedirectURI
,
OAuthPlatformSora
)
parsed
,
err
:=
url
.
Parse
(
authURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"Parse URL failed: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
got
:=
q
.
Get
(
"client_id"
);
got
!=
ClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)"
,
got
,
ClientID
)
}
if
got
:=
q
.
Get
(
"codex_cli_simplified_flow"
);
got
!=
""
{
t
.
Fatalf
(
"codex flow should be empty for sora, got=%q"
,
got
)
}
if
got
:=
q
.
Get
(
"id_token_add_organizations"
);
got
!=
"true"
{
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
backend/internal/pkg/openai/request.go
View file @
3d79773b
package
openai
import
"strings"
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var
CodexCLIUserAgentPrefixes
=
[]
string
{
...
...
@@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{
"codex_cli_rs/"
,
}
// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。
// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。
var
CodexOfficialClientUserAgentPrefixes
=
[]
string
{
"codex_cli_rs/"
,
"codex_vscode/"
,
"codex_app/"
,
"codex_chatgpt_desktop/"
,
"codex_atlas/"
,
"codex_exec/"
,
"codex_sdk_ts/"
,
"codex "
,
}
// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。
// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。
// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。
var
CodexOfficialClientOriginatorPrefixes
=
[]
string
{
"codex_"
,
"codex "
,
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func
IsCodexCLIRequest
(
userAgent
string
)
bool
{
for
_
,
prefix
:=
range
CodexCLIUserAgentPrefixes
{
if
len
(
userAgent
)
>=
len
(
prefix
)
&&
userAgent
[
:
len
(
prefix
)]
==
prefix
{
ua
:=
normalizeCodexClientHeader
(
userAgent
)
if
ua
==
""
{
return
false
}
return
matchCodexClientHeaderPrefixes
(
ua
,
CodexCLIUserAgentPrefixes
)
}
// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。
// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。
func
IsCodexOfficialClientRequest
(
userAgent
string
)
bool
{
ua
:=
normalizeCodexClientHeader
(
userAgent
)
if
ua
==
""
{
return
false
}
return
matchCodexClientHeaderPrefixes
(
ua
,
CodexOfficialClientUserAgentPrefixes
)
}
// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。
func
IsCodexOfficialClientOriginator
(
originator
string
)
bool
{
v
:=
normalizeCodexClientHeader
(
originator
)
if
v
==
""
{
return
false
}
return
matchCodexClientHeaderPrefixes
(
v
,
CodexOfficialClientOriginatorPrefixes
)
}
func
normalizeCodexClientHeader
(
value
string
)
string
{
return
strings
.
ToLower
(
strings
.
TrimSpace
(
value
))
}
func
matchCodexClientHeaderPrefixes
(
value
string
,
prefixes
[]
string
)
bool
{
for
_
,
prefix
:=
range
prefixes
{
normalizedPrefix
:=
normalizeCodexClientHeader
(
prefix
)
if
normalizedPrefix
==
""
{
continue
}
// 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。
if
strings
.
HasPrefix
(
value
,
normalizedPrefix
)
||
strings
.
Contains
(
value
,
normalizedPrefix
)
{
return
true
}
}
...
...
backend/internal/pkg/openai/request_test.go
0 → 100644
View file @
3d79773b
package
openai
import
"testing"
func
TestIsCodexCLIRequest
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
ua
string
want
bool
}{
{
name
:
"codex_cli_rs 前缀"
,
ua
:
"codex_cli_rs/0.1.0"
,
want
:
true
},
{
name
:
"codex_vscode 前缀"
,
ua
:
"codex_vscode/1.2.3"
,
want
:
true
},
{
name
:
"大小写混合"
,
ua
:
"Codex_CLI_Rs/0.1.0"
,
want
:
true
},
{
name
:
"复合 UA 包含 codex"
,
ua
:
"Mozilla/5.0 codex_cli_rs/0.1.0"
,
want
:
true
},
{
name
:
"空白包裹"
,
ua
:
" codex_vscode/1.2.3 "
,
want
:
true
},
{
name
:
"非 codex"
,
ua
:
"curl/8.0.1"
,
want
:
false
},
{
name
:
"空字符串"
,
ua
:
""
,
want
:
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
IsCodexCLIRequest
(
tt
.
ua
)
if
got
!=
tt
.
want
{
t
.
Fatalf
(
"IsCodexCLIRequest(%q) = %v, want %v"
,
tt
.
ua
,
got
,
tt
.
want
)
}
})
}
}
func
TestIsCodexOfficialClientRequest
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
ua
string
want
bool
}{
{
name
:
"codex_cli_rs 前缀"
,
ua
:
"codex_cli_rs/0.98.0"
,
want
:
true
},
{
name
:
"codex_vscode 前缀"
,
ua
:
"codex_vscode/1.0.0"
,
want
:
true
},
{
name
:
"codex_app 前缀"
,
ua
:
"codex_app/0.1.0"
,
want
:
true
},
{
name
:
"codex_chatgpt_desktop 前缀"
,
ua
:
"codex_chatgpt_desktop/1.0.0"
,
want
:
true
},
{
name
:
"codex_atlas 前缀"
,
ua
:
"codex_atlas/1.0.0"
,
want
:
true
},
{
name
:
"codex_exec 前缀"
,
ua
:
"codex_exec/0.1.0"
,
want
:
true
},
{
name
:
"codex_sdk_ts 前缀"
,
ua
:
"codex_sdk_ts/0.1.0"
,
want
:
true
},
{
name
:
"Codex 桌面 UA"
,
ua
:
"Codex Desktop/1.2.3"
,
want
:
true
},
{
name
:
"复合 UA 包含 codex_app"
,
ua
:
"Mozilla/5.0 codex_app/0.1.0"
,
want
:
true
},
{
name
:
"大小写混合"
,
ua
:
"Codex_VSCode/1.2.3"
,
want
:
true
},
{
name
:
"非 codex"
,
ua
:
"curl/8.0.1"
,
want
:
false
},
{
name
:
"空字符串"
,
ua
:
""
,
want
:
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
IsCodexOfficialClientRequest
(
tt
.
ua
)
if
got
!=
tt
.
want
{
t
.
Fatalf
(
"IsCodexOfficialClientRequest(%q) = %v, want %v"
,
tt
.
ua
,
got
,
tt
.
want
)
}
})
}
}
func
TestIsCodexOfficialClientOriginator
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
originator
string
want
bool
}{
{
name
:
"codex_cli_rs"
,
originator
:
"codex_cli_rs"
,
want
:
true
},
{
name
:
"codex_vscode"
,
originator
:
"codex_vscode"
,
want
:
true
},
{
name
:
"codex_app"
,
originator
:
"codex_app"
,
want
:
true
},
{
name
:
"codex_chatgpt_desktop"
,
originator
:
"codex_chatgpt_desktop"
,
want
:
true
},
{
name
:
"codex_atlas"
,
originator
:
"codex_atlas"
,
want
:
true
},
{
name
:
"codex_exec"
,
originator
:
"codex_exec"
,
want
:
true
},
{
name
:
"codex_sdk_ts"
,
originator
:
"codex_sdk_ts"
,
want
:
true
},
{
name
:
"Codex 前缀"
,
originator
:
"Codex Desktop"
,
want
:
true
},
{
name
:
"空白包裹"
,
originator
:
" codex_vscode "
,
want
:
true
},
{
name
:
"非 codex"
,
originator
:
"my_client"
,
want
:
false
},
{
name
:
"空字符串"
,
originator
:
""
,
want
:
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
IsCodexOfficialClientOriginator
(
tt
.
originator
)
if
got
!=
tt
.
want
{
t
.
Fatalf
(
"IsCodexOfficialClientOriginator(%q) = %v, want %v"
,
tt
.
originator
,
got
,
tt
.
want
)
}
})
}
}
backend/internal/pkg/proxyurl/parse.go
0 → 100644
View file @
3d79773b
// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连)
//
// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。
// 直接使用 url.Parse 处理代理 URL 是被禁止的。
// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败,
// 而不是在运行时静默回退到直连(产生 IP 关联风险)。
package
proxyurl
import
(
"fmt"
"net/url"
"strings"
)
// allowedSchemes 代理协议白名单
var
allowedSchemes
=
map
[
string
]
bool
{
"http"
:
true
,
"https"
:
true
,
"socks5"
:
true
,
"socks5h"
:
true
,
}
// Parse 解析并验证代理 URL。
//
// 语义:
// - 空字符串 → ("", nil, nil),表示直连
// - 非空且有效 → (trimmed, *url.URL, nil)
// - 非空但无效 → ("", nil, error),fail-fast 不回退
//
// 验证规则:
// - TrimSpace 后为空视为直连
// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露)
// - Host 为空返回 error(用 Redacted() 脱敏)
// - Scheme 必须为 http/https/socks5/socks5h
// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏)
func
Parse
(
raw
string
)
(
trimmed
string
,
parsed
*
url
.
URL
,
err
error
)
{
trimmed
=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
""
,
nil
,
nil
}
parsed
,
err
=
url
.
Parse
(
trimmed
)
if
err
!=
nil
{
// 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据)
return
""
,
nil
,
fmt
.
Errorf
(
"invalid proxy URL: %v"
,
err
)
}
if
parsed
.
Host
==
""
||
parsed
.
Hostname
()
==
""
{
return
""
,
nil
,
fmt
.
Errorf
(
"proxy URL missing host: %s"
,
parsed
.
Redacted
())
}
scheme
:=
strings
.
ToLower
(
parsed
.
Scheme
)
if
!
allowedSchemes
[
scheme
]
{
return
""
,
nil
,
fmt
.
Errorf
(
"unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)"
,
scheme
)
}
// 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。
// Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS,
// 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。
if
scheme
==
"socks5"
{
parsed
.
Scheme
=
"socks5h"
trimmed
=
parsed
.
String
()
}
return
trimmed
,
parsed
,
nil
}
backend/internal/pkg/proxyurl/parse_test.go
0 → 100644
View file @
3d79773b
package
proxyurl
import
(
"strings"
"testing"
)
func
TestParse_空字符串直连
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
""
)
if
err
!=
nil
{
t
.
Fatalf
(
"空字符串应直连: %v"
,
err
)
}
if
trimmed
!=
""
{
t
.
Errorf
(
"trimmed 应为空: got %q"
,
trimmed
)
}
if
parsed
!=
nil
{
t
.
Errorf
(
"parsed 应为 nil: got %v"
,
parsed
)
}
}
func
TestParse_空白字符串直连
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
" "
)
if
err
!=
nil
{
t
.
Fatalf
(
"空白字符串应直连: %v"
,
err
)
}
if
trimmed
!=
""
{
t
.
Errorf
(
"trimmed 应为空: got %q"
,
trimmed
)
}
if
parsed
!=
nil
{
t
.
Errorf
(
"parsed 应为 nil: got %v"
,
parsed
)
}
}
func
TestParse_有效HTTP代理
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
"http://proxy.example.com:8080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"有效 HTTP 代理应成功: %v"
,
err
)
}
if
trimmed
!=
"http://proxy.example.com:8080"
{
t
.
Errorf
(
"trimmed 不匹配: got %q"
,
trimmed
)
}
if
parsed
==
nil
{
t
.
Fatal
(
"parsed 不应为 nil"
)
}
if
parsed
.
Host
!=
"proxy.example.com:8080"
{
t
.
Errorf
(
"Host 不匹配: got %q"
,
parsed
.
Host
)
}
}
func
TestParse_有效HTTPS代理
(
t
*
testing
.
T
)
{
_
,
parsed
,
err
:=
Parse
(
"https://proxy.example.com:443"
)
if
err
!=
nil
{
t
.
Fatalf
(
"有效 HTTPS 代理应成功: %v"
,
err
)
}
if
parsed
.
Scheme
!=
"https"
{
t
.
Errorf
(
"Scheme 不匹配: got %q"
,
parsed
.
Scheme
)
}
}
func
TestParse_有效SOCKS5代理_自动升级为SOCKS5H
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
"socks5://127.0.0.1:1080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"有效 SOCKS5 代理应成功: %v"
,
err
)
}
// socks5 自动升级为 socks5h,确保 DNS 由代理端解析
if
trimmed
!=
"socks5h://127.0.0.1:1080"
{
t
.
Errorf
(
"trimmed 应升级为 socks5h: got %q"
,
trimmed
)
}
if
parsed
.
Scheme
!=
"socks5h"
{
t
.
Errorf
(
"Scheme 应升级为 socks5h: got %q"
,
parsed
.
Scheme
)
}
}
func
TestParse_无效URL
(
t
*
testing
.
T
)
{
_
,
_
,
err
:=
Parse
(
"://invalid"
)
if
err
==
nil
{
t
.
Fatal
(
"无效 URL 应返回错误"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"invalid proxy URL"
)
{
t
.
Errorf
(
"错误信息应包含 'invalid proxy URL': got %s"
,
err
.
Error
())
}
}
func
TestParse_缺少Host
(
t
*
testing
.
T
)
{
_
,
_
,
err
:=
Parse
(
"http://"
)
if
err
==
nil
{
t
.
Fatal
(
"缺少 host 应返回错误"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"missing host"
)
{
t
.
Errorf
(
"错误信息应包含 'missing host': got %s"
,
err
.
Error
())
}
}
func
TestParse_不支持的Scheme
(
t
*
testing
.
T
)
{
_
,
_
,
err
:=
Parse
(
"ftp://proxy.example.com:21"
)
if
err
==
nil
{
t
.
Fatal
(
"不支持的 scheme 应返回错误"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"unsupported proxy scheme"
)
{
t
.
Errorf
(
"错误信息应包含 'unsupported proxy scheme': got %s"
,
err
.
Error
())
}
}
func
TestParse_含密码URL脱敏
(
t
*
testing
.
T
)
{
// 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h
trimmed
,
parsed
,
err
:=
Parse
(
"socks5://user:secret_password@proxy.local:1080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"含密码的有效 URL 应成功: %v"
,
err
)
}
if
trimmed
==
""
||
parsed
==
nil
{
t
.
Fatal
(
"应返回非空结果"
)
}
if
parsed
.
Scheme
!=
"socks5h"
{
t
.
Errorf
(
"Scheme 应升级为 socks5h: got %q"
,
parsed
.
Scheme
)
}
if
!
strings
.
HasPrefix
(
trimmed
,
"socks5h://"
)
{
t
.
Errorf
(
"trimmed 应以 socks5h:// 开头: got %q"
,
trimmed
)
}
if
parsed
.
User
==
nil
{
t
.
Error
(
"升级后应保留 UserInfo"
)
}
// 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径)
_
,
_
,
err
=
Parse
(
"http://user:secret_password@:0/"
)
if
err
==
nil
{
t
.
Fatal
(
"缺少 host 应返回错误"
)
}
if
strings
.
Contains
(
err
.
Error
(),
"secret_password"
)
{
t
.
Error
(
"错误信息不应包含明文密码"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"missing host"
)
{
t
.
Errorf
(
"错误信息应包含 'missing host': got %s"
,
err
.
Error
())
}
}
func
TestParse_带空白的有效URL
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
" http://proxy.example.com:8080 "
)
if
err
!=
nil
{
t
.
Fatalf
(
"带空白的有效 URL 应成功: %v"
,
err
)
}
if
trimmed
!=
"http://proxy.example.com:8080"
{
t
.
Errorf
(
"trimmed 应去除空白: got %q"
,
trimmed
)
}
if
parsed
==
nil
{
t
.
Fatal
(
"parsed 不应为 nil"
)
}
}
func
TestParse_Scheme大小写不敏感
(
t
*
testing
.
T
)
{
// 大写 SOCKS5 应被接受并升级为 socks5h
trimmed
,
parsed
,
err
:=
Parse
(
"SOCKS5://proxy.example.com:1080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"大写 SOCKS5 应被接受: %v"
,
err
)
}
if
parsed
.
Scheme
!=
"socks5h"
{
t
.
Errorf
(
"大写 SOCKS5 Scheme 应升级为 socks5h: got %q"
,
parsed
.
Scheme
)
}
if
!
strings
.
HasPrefix
(
trimmed
,
"socks5h://"
)
{
t
.
Errorf
(
"大写 SOCKS5 trimmed 应升级为 socks5h://: got %q"
,
trimmed
)
}
// 大写 HTTP 应被接受(不变)
_
,
_
,
err
=
Parse
(
"HTTP://proxy.example.com:8080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"大写 HTTP 应被接受: %v"
,
err
)
}
}
func
TestParse_带认证的有效代理
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
"http://user:pass@proxy.example.com:8080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"带认证的代理 URL 应成功: %v"
,
err
)
}
if
parsed
.
User
==
nil
{
t
.
Error
(
"应保留 UserInfo"
)
}
if
trimmed
!=
"http://user:pass@proxy.example.com:8080"
{
t
.
Errorf
(
"trimmed 不匹配: got %q"
,
trimmed
)
}
}
func
TestParse_IPv6地址
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
"http://[::1]:8080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"IPv6 代理 URL 应成功: %v"
,
err
)
}
if
parsed
.
Hostname
()
!=
"::1"
{
t
.
Errorf
(
"Hostname 不匹配: got %q"
,
parsed
.
Hostname
())
}
if
trimmed
!=
"http://[::1]:8080"
{
t
.
Errorf
(
"trimmed 不匹配: got %q"
,
trimmed
)
}
}
func
TestParse_SOCKS5H保持不变
(
t
*
testing
.
T
)
{
trimmed
,
parsed
,
err
:=
Parse
(
"socks5h://proxy.local:1080"
)
if
err
!=
nil
{
t
.
Fatalf
(
"有效 SOCKS5H 代理应成功: %v"
,
err
)
}
// socks5h 不需要升级,应保持原样
if
trimmed
!=
"socks5h://proxy.local:1080"
{
t
.
Errorf
(
"trimmed 不应变化: got %q"
,
trimmed
)
}
if
parsed
.
Scheme
!=
"socks5h"
{
t
.
Errorf
(
"Scheme 应保持 socks5h: got %q"
,
parsed
.
Scheme
)
}
}
func
TestParse_无Scheme裸地址
(
t
*
testing
.
T
)
{
// 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空
_
,
_
,
err
:=
Parse
(
"proxy.example.com:8080"
)
if
err
==
nil
{
t
.
Fatal
(
"无 scheme 的裸地址应返回错误"
)
}
}
backend/internal/pkg/proxyutil/dialer.go
View file @
3d79773b
...
...
@@ -2,7 +2,11 @@
//
// 支持的代理协议:
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS)
// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS)
// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐)
//
// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://,
// 确保 DNS 也由代理端解析,防止 DNS 泄漏。
package
proxyutil
import
(
...
...
@@ -20,7 +24,8 @@ import (
//
// 支持的协议:
// - http/https: 设置 transport.Proxy
// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS)
// - socks5: 设置 transport.DialContext(客户端本地解析 DNS)
// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐)
//
// 参数:
// - transport: 需要配置的 http.Transport
...
...
backend/internal/pkg/response/response.go
View file @
3d79773b
...
...
@@ -7,6 +7,7 @@ import (
"net/http"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/gin-gonic/gin"
)
...
...
@@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool {
// Log internal errors with full details for debugging
if
statusCode
>=
500
&&
c
.
Request
!=
nil
{
log
.
Printf
(
"[ERROR] %s %s
\n
Error: %s"
,
c
.
Request
.
Method
,
c
.
Request
.
URL
.
Path
,
err
.
Error
())
log
.
Printf
(
"[ERROR] %s %s
\n
Error: %s"
,
c
.
Request
.
Method
,
c
.
Request
.
URL
.
Path
,
logredact
.
RedactText
(
err
.
Error
())
)
}
ErrorWithDetails
(
c
,
statusCode
,
status
.
Message
,
status
.
Reason
,
status
.
Metadata
)
...
...
backend/internal/pkg/response/response_test.go
View file @
3d79773b
...
...
@@ -14,6 +14,44 @@ import (
"github.com/stretchr/testify/require"
)
// ---------- 辅助函数 ----------
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
func
parseResponseBody
(
t
*
testing
.
T
,
w
*
httptest
.
ResponseRecorder
)
Response
{
t
.
Helper
()
var
got
Response
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
got
))
return
got
}
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
func
parsePaginatedBody
(
t
*
testing
.
T
,
w
*
httptest
.
ResponseRecorder
)
(
Response
,
PaginatedData
)
{
t
.
Helper
()
// 先用 raw json 解析,因为 Data 是 any 类型
var
raw
struct
{
Code
int
`json:"code"`
Message
string
`json:"message"`
Reason
string
`json:"reason,omitempty"`
Data
json
.
RawMessage
`json:"data,omitempty"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
raw
))
var
pd
PaginatedData
require
.
NoError
(
t
,
json
.
Unmarshal
(
raw
.
Data
,
&
pd
))
return
Response
{
Code
:
raw
.
Code
,
Message
:
raw
.
Message
,
Reason
:
raw
.
Reason
},
pd
}
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
func
newContextWithQuery
(
query
string
)
(
*
httptest
.
ResponseRecorder
,
*
gin
.
Context
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/?"
+
query
,
nil
)
return
w
,
c
}
// ---------- 现有测试 ----------
func
TestErrorWithDetails
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
})
}
}
// ---------- 新增测试 ----------
func
TestSuccess
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
data
any
wantCode
int
wantBody
Response
}{
{
name
:
"返回字符串数据"
,
data
:
"hello"
,
wantCode
:
http
.
StatusOK
,
wantBody
:
Response
{
Code
:
0
,
Message
:
"success"
,
Data
:
"hello"
},
},
{
name
:
"返回nil数据"
,
data
:
nil
,
wantCode
:
http
.
StatusOK
,
wantBody
:
Response
{
Code
:
0
,
Message
:
"success"
},
},
{
name
:
"返回map数据"
,
data
:
map
[
string
]
string
{
"key"
:
"value"
},
wantCode
:
http
.
StatusOK
,
wantBody
:
Response
{
Code
:
0
,
Message
:
"success"
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
Success
(
c
,
tt
.
data
)
require
.
Equal
(
t
,
tt
.
wantCode
,
w
.
Code
)
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
0
,
got
.
Code
)
require
.
Equal
(
t
,
"success"
,
got
.
Message
)
if
tt
.
data
==
nil
{
require
.
Nil
(
t
,
got
.
Data
)
}
else
{
require
.
NotNil
(
t
,
got
.
Data
)
}
})
}
}
func
TestCreated
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
data
any
wantCode
int
}{
{
name
:
"创建成功_返回数据"
,
data
:
map
[
string
]
int
{
"id"
:
42
},
wantCode
:
http
.
StatusCreated
,
},
{
name
:
"创建成功_nil数据"
,
data
:
nil
,
wantCode
:
http
.
StatusCreated
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
Created
(
c
,
tt
.
data
)
require
.
Equal
(
t
,
tt
.
wantCode
,
w
.
Code
)
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
0
,
got
.
Code
)
require
.
Equal
(
t
,
"success"
,
got
.
Message
)
})
}
}
func
TestError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
statusCode
int
message
string
}{
{
name
:
"400错误"
,
statusCode
:
http
.
StatusBadRequest
,
message
:
"bad request"
,
},
{
name
:
"500错误"
,
statusCode
:
http
.
StatusInternalServerError
,
message
:
"internal error"
,
},
{
name
:
"自定义状态码"
,
statusCode
:
418
,
message
:
"I'm a teapot"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
Error
(
c
,
tt
.
statusCode
,
tt
.
message
)
require
.
Equal
(
t
,
tt
.
statusCode
,
w
.
Code
)
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
tt
.
statusCode
,
got
.
Code
)
require
.
Equal
(
t
,
tt
.
message
,
got
.
Message
)
require
.
Empty
(
t
,
got
.
Reason
)
require
.
Nil
(
t
,
got
.
Metadata
)
require
.
Nil
(
t
,
got
.
Data
)
})
}
}
func
TestBadRequest
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
BadRequest
(
c
,
"参数无效"
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
)
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
got
.
Code
)
require
.
Equal
(
t
,
"参数无效"
,
got
.
Message
)
}
func
TestUnauthorized
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
Unauthorized
(
c
,
"未登录"
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
got
.
Code
)
require
.
Equal
(
t
,
"未登录"
,
got
.
Message
)
}
func
TestForbidden
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
Forbidden
(
c
,
"无权限"
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
w
.
Code
)
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
got
.
Code
)
require
.
Equal
(
t
,
"无权限"
,
got
.
Message
)
}
func
TestNotFound
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
NotFound
(
c
,
"资源不存在"
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
w
.
Code
)
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
got
.
Code
)
require
.
Equal
(
t
,
"资源不存在"
,
got
.
Message
)
}
func
TestInternalError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
InternalError
(
c
,
"服务器内部错误"
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
w
.
Code
)
got
:=
parseResponseBody
(
t
,
w
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
got
.
Code
)
require
.
Equal
(
t
,
"服务器内部错误"
,
got
.
Message
)
}
func
TestPaginated
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
items
any
total
int64
page
int
pageSize
int
wantPages
int
wantTotal
int64
wantPage
int
wantPageSize
int
}{
{
name
:
"标准分页_多页"
,
items
:
[]
string
{
"a"
,
"b"
},
total
:
25
,
page
:
1
,
pageSize
:
10
,
wantPages
:
3
,
wantTotal
:
25
,
wantPage
:
1
,
wantPageSize
:
10
,
},
{
name
:
"总数刚好整除"
,
items
:
[]
string
{
"a"
},
total
:
20
,
page
:
2
,
pageSize
:
10
,
wantPages
:
2
,
wantTotal
:
20
,
wantPage
:
2
,
wantPageSize
:
10
,
},
{
name
:
"总数为0_pages至少为1"
,
items
:
[]
string
{},
total
:
0
,
page
:
1
,
pageSize
:
10
,
wantPages
:
1
,
wantTotal
:
0
,
wantPage
:
1
,
wantPageSize
:
10
,
},
{
name
:
"单页数据"
,
items
:
[]
int
{
1
,
2
,
3
},
total
:
3
,
page
:
1
,
pageSize
:
20
,
wantPages
:
1
,
wantTotal
:
3
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"总数为1"
,
items
:
[]
string
{
"only"
},
total
:
1
,
page
:
1
,
pageSize
:
10
,
wantPages
:
1
,
wantTotal
:
1
,
wantPage
:
1
,
wantPageSize
:
10
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
Paginated
(
c
,
tt
.
items
,
tt
.
total
,
tt
.
page
,
tt
.
pageSize
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
resp
,
pd
:=
parsePaginatedBody
(
t
,
w
)
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
"success"
,
resp
.
Message
)
require
.
Equal
(
t
,
tt
.
wantTotal
,
pd
.
Total
)
require
.
Equal
(
t
,
tt
.
wantPage
,
pd
.
Page
)
require
.
Equal
(
t
,
tt
.
wantPageSize
,
pd
.
PageSize
)
require
.
Equal
(
t
,
tt
.
wantPages
,
pd
.
Pages
)
})
}
}
func
TestPaginatedWithResult
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
items
any
pagination
*
PaginationResult
wantTotal
int64
wantPage
int
wantPageSize
int
wantPages
int
}{
{
name
:
"正常分页结果"
,
items
:
[]
string
{
"a"
,
"b"
},
pagination
:
&
PaginationResult
{
Total
:
50
,
Page
:
3
,
PageSize
:
10
,
Pages
:
5
,
},
wantTotal
:
50
,
wantPage
:
3
,
wantPageSize
:
10
,
wantPages
:
5
,
},
{
name
:
"pagination为nil_使用默认值"
,
items
:
[]
string
{},
pagination
:
nil
,
wantTotal
:
0
,
wantPage
:
1
,
wantPageSize
:
20
,
wantPages
:
1
,
},
{
name
:
"单页结果"
,
items
:
[]
int
{
1
},
pagination
:
&
PaginationResult
{
Total
:
1
,
Page
:
1
,
PageSize
:
20
,
Pages
:
1
,
},
wantTotal
:
1
,
wantPage
:
1
,
wantPageSize
:
20
,
wantPages
:
1
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
PaginatedWithResult
(
c
,
tt
.
items
,
tt
.
pagination
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
resp
,
pd
:=
parsePaginatedBody
(
t
,
w
)
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
"success"
,
resp
.
Message
)
require
.
Equal
(
t
,
tt
.
wantTotal
,
pd
.
Total
)
require
.
Equal
(
t
,
tt
.
wantPage
,
pd
.
Page
)
require
.
Equal
(
t
,
tt
.
wantPageSize
,
pd
.
PageSize
)
require
.
Equal
(
t
,
tt
.
wantPages
,
pd
.
Pages
)
})
}
}
func
TestParsePagination
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
query
string
wantPage
int
wantPageSize
int
}{
{
name
:
"无参数_使用默认值"
,
query
:
""
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"仅指定page"
,
query
:
"page=3"
,
wantPage
:
3
,
wantPageSize
:
20
,
},
{
name
:
"仅指定page_size"
,
query
:
"page_size=50"
,
wantPage
:
1
,
wantPageSize
:
50
,
},
{
name
:
"同时指定page和page_size"
,
query
:
"page=2&page_size=30"
,
wantPage
:
2
,
wantPageSize
:
30
,
},
{
name
:
"使用limit代替page_size"
,
query
:
"limit=15"
,
wantPage
:
1
,
wantPageSize
:
15
,
},
{
name
:
"page_size优先于limit"
,
query
:
"page_size=25&limit=50"
,
wantPage
:
1
,
wantPageSize
:
25
,
},
{
name
:
"page为0_使用默认值"
,
query
:
"page=0"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"page_size超过1000_使用默认值"
,
query
:
"page_size=1001"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"page_size恰好1000_有效"
,
query
:
"page_size=1000"
,
wantPage
:
1
,
wantPageSize
:
1000
,
},
{
name
:
"page为非数字_使用默认值"
,
query
:
"page=abc"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"page_size为非数字_使用默认值"
,
query
:
"page_size=xyz"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"limit为非数字_使用默认值"
,
query
:
"limit=abc"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"page_size为0_使用默认值"
,
query
:
"page_size=0"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"limit为0_使用默认值"
,
query
:
"limit=0"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"大页码"
,
query
:
"page=999&page_size=100"
,
wantPage
:
999
,
wantPageSize
:
100
,
},
{
name
:
"page_size为1_最小有效值"
,
query
:
"page_size=1"
,
wantPage
:
1
,
wantPageSize
:
1
,
},
{
name
:
"混合数字和字母的page"
,
query
:
"page=12a"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
{
name
:
"limit超过1000_使用默认值"
,
query
:
"limit=2000"
,
wantPage
:
1
,
wantPageSize
:
20
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
_
,
c
:=
newContextWithQuery
(
tt
.
query
)
page
,
pageSize
:=
ParsePagination
(
c
)
require
.
Equal
(
t
,
tt
.
wantPage
,
page
,
"page 不符合预期"
)
require
.
Equal
(
t
,
tt
.
wantPageSize
,
pageSize
,
"pageSize 不符合预期"
)
})
}
}
func
Test_parseInt
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
string
wantVal
int
wantErr
bool
}{
{
name
:
"正常数字"
,
input
:
"123"
,
wantVal
:
123
,
wantErr
:
false
,
},
{
name
:
"零"
,
input
:
"0"
,
wantVal
:
0
,
wantErr
:
false
,
},
{
name
:
"单个数字"
,
input
:
"5"
,
wantVal
:
5
,
wantErr
:
false
,
},
{
name
:
"大数字"
,
input
:
"99999"
,
wantVal
:
99999
,
wantErr
:
false
,
},
{
name
:
"包含字母_返回0"
,
input
:
"abc"
,
wantVal
:
0
,
wantErr
:
false
,
},
{
name
:
"数字开头接字母_返回0"
,
input
:
"12a"
,
wantVal
:
0
,
wantErr
:
false
,
},
{
name
:
"包含负号_返回0"
,
input
:
"-1"
,
wantVal
:
0
,
wantErr
:
false
,
},
{
name
:
"包含小数点_返回0"
,
input
:
"1.5"
,
wantVal
:
0
,
wantErr
:
false
,
},
{
name
:
"包含空格_返回0"
,
input
:
"1 2"
,
wantVal
:
0
,
wantErr
:
false
,
},
{
name
:
"空字符串"
,
input
:
""
,
wantVal
:
0
,
wantErr
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
val
,
err
:=
parseInt
(
tt
.
input
)
if
tt
.
wantErr
{
require
.
Error
(
t
,
err
)
}
else
{
require
.
NoError
(
t
,
err
)
}
require
.
Equal
(
t
,
tt
.
wantVal
,
val
)
})
}
}
backend/internal/pkg/tlsfingerprint/dialer.go
View file @
3d79773b
...
...
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
"cipher_suites"
,
len
(
spec
.
CipherSuites
),
"extensions"
,
len
(
spec
.
Extensions
),
"compression_methods"
,
spec
.
CompressionMethods
,
"tls_vers_max"
,
fmt
.
Sprintf
(
"0x%04x"
,
spec
.
TLSVersMax
)
,
"tls_vers_min"
,
fmt
.
Sprintf
(
"0x%04x"
,
spec
.
TLSVersMin
)
)
"tls_vers_max"
,
spec
.
TLSVersMax
,
"tls_vers_min"
,
spec
.
TLSVersMin
)
if
d
.
profile
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_using_profile"
,
"name"
,
d
.
profile
.
Name
,
"grease"
,
d
.
profile
.
EnableGREASE
)
...
...
@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
return
nil
,
fmt
.
Errorf
(
"apply TLS preset: %w"
,
err
)
}
if
err
:=
tlsConn
.
Handshake
(
);
err
!=
nil
{
if
err
:=
tlsConn
.
Handshake
Context
(
ctx
);
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_handshake_failed"
,
"error"
,
err
)
_
=
conn
.
Close
()
return
nil
,
fmt
.
Errorf
(
"TLS handshake failed: %w"
,
err
)
...
...
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_socks5_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
)
,
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
)
,
"version"
,
state
.
Version
,
"cipher_suite"
,
state
.
CipherSuite
,
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
...
...
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_http_proxy_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
)
,
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
)
,
"version"
,
state
.
Version
,
"cipher_suite"
,
state
.
CipherSuite
,
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
...
...
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
// Log successful handshake details
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
)
,
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
)
,
"version"
,
state
.
Version
,
"cipher_suite"
,
state
.
CipherSuite
,
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
...
...
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
View file @
3d79773b
...
...
@@ -30,7 +30,8 @@ func skipIfExternalServiceUnavailable(t *testing.T, err error) {
strings
.
Contains
(
errStr
,
"connection refused"
)
||
strings
.
Contains
(
errStr
,
"no such host"
)
||
strings
.
Contains
(
errStr
,
"network is unreachable"
)
||
strings
.
Contains
(
errStr
,
"timeout"
)
{
strings
.
Contains
(
errStr
,
"timeout"
)
||
strings
.
Contains
(
errStr
,
"deadline exceeded"
)
{
t
.
Skipf
(
"skipping test: external service unavailable: %v"
,
err
)
}
t
.
Fatalf
(
"failed to get fingerprint: %v"
,
err
)
...
...
backend/internal/pkg/tlsfingerprint/dialer_test.go
View file @
3d79773b
//go:build unit
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Unit tests for TLS fingerprint dialer.
...
...
@@ -9,26 +11,161 @@
package
tlsfingerprint
import
(
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
type
FingerprintResponse
struct
{
IP
string
`json:"ip"`
TLS
TLSInfo
`json:"tls"`
HTTP2
any
`json:"http2"`
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func
TestDialerBasicConnection
(
t
*
testing
.
T
)
{
skipNetworkTest
(
t
)
// Create a dialer with default profile
profile
:=
&
Profile
{
Name
:
"Test Profile"
,
EnableGREASE
:
false
,
}
dialer
:=
NewDialer
(
profile
,
nil
)
// Create HTTP client with custom TLS dialer
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
// Make a request to a known HTTPS endpoint
resp
,
err
:=
client
.
Get
(
"https://www.google.com"
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to connect: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
t
.
Errorf
(
"expected status 200, got %d"
,
resp
.
StatusCode
)
}
}
// TLSInfo contains TLS fingerprint details.
type
TLSInfo
struct
{
JA3
string
`json:"ja3"`
JA3Hash
string
`json:"ja3_hash"`
JA4
string
`json:"ja4"`
PeetPrint
string
`json:"peetprint"`
PeetPrintHash
string
`json:"peetprint_hash"`
ClientRandom
string
`json:"client_random"`
SessionID
string
`json:"session_id"`
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func
TestJA3Fingerprint
(
t
*
testing
.
T
)
{
skipNetworkTest
(
t
)
profile
:=
&
Profile
{
Name
:
"Claude CLI Test"
,
EnableGREASE
:
false
,
}
dialer
:=
NewDialer
(
profile
,
nil
)
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
// Use tls.peet.ws fingerprint detection API
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://tls.peet.ws/api/all"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to create request: %v"
,
err
)
}
req
.
Header
.
Set
(
"User-Agent"
,
"Claude Code/2.0.0 Node.js/20.0.0"
)
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to get fingerprint: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to read response: %v"
,
err
)
}
var
fpResp
FingerprintResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
fpResp
);
err
!=
nil
{
t
.
Logf
(
"Response body: %s"
,
string
(
body
))
t
.
Fatalf
(
"failed to parse fingerprint response: %v"
,
err
)
}
// Log all fingerprint information
t
.
Logf
(
"JA3: %s"
,
fpResp
.
TLS
.
JA3
)
t
.
Logf
(
"JA3 Hash: %s"
,
fpResp
.
TLS
.
JA3Hash
)
t
.
Logf
(
"JA4: %s"
,
fpResp
.
TLS
.
JA4
)
t
.
Logf
(
"PeetPrint: %s"
,
fpResp
.
TLS
.
PeetPrint
)
t
.
Logf
(
"PeetPrint Hash: %s"
,
fpResp
.
TLS
.
PeetPrintHash
)
// Verify JA3 hash matches expected value
expectedJA3Hash
:=
"1a28e69016765d92e3b381168d68922c"
if
fpResp
.
TLS
.
JA3Hash
==
expectedJA3Hash
{
t
.
Logf
(
"✓ JA3 hash matches expected value: %s"
,
expectedJA3Hash
)
}
else
{
t
.
Errorf
(
"✗ JA3 hash mismatch: got %s, expected %s"
,
fpResp
.
TLS
.
JA3Hash
,
expectedJA3Hash
)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix
:=
"_a33745022dd6_1f22a2ca17c4"
if
strings
.
HasSuffix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
{
t
.
Logf
(
"✓ JA4 suffix matches expected value: %s"
,
expectedJA4Suffix
)
}
else
{
t
.
Errorf
(
"✗ JA4 suffix mismatch: got %s, expected suffix %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix
:=
"t13d5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)"
,
expectedJA4Prefix
)
}
else
{
// Also accept 'i' variant for IP connections
altPrefix
:=
"t13i5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
altPrefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches (IP variant): %s"
,
altPrefix
)
}
else
{
t
.
Errorf
(
"✗ JA4 prefix mismatch: got %s, expected %s or %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
,
altPrefix
)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
"4866-4867-4865"
)
{
t
.
Logf
(
"✓ JA3 contains expected TLS 1.3 cipher suites"
)
}
else
{
t
.
Logf
(
"Warning: JA3 does not contain expected TLS 1.3 cipher suites"
)
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions
:=
"0-11-10-35-16-22-23-13-43-45-51"
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
expectedExtensions
)
{
t
.
Logf
(
"✓ JA3 contains expected extension list: %s"
,
expectedExtensions
)
}
else
{
t
.
Logf
(
"Warning: JA3 extension list may differ"
)
}
}
func
skipNetworkTest
(
t
*
testing
.
T
)
{
if
testing
.
Short
()
{
t
.
Skip
(
"跳过网络测试(short 模式)"
)
}
if
os
.
Getenv
(
"TLSFINGERPRINT_NETWORK_TESTS"
)
!=
"1"
{
t
.
Skip
(
"跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)"
)
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
...
...
@@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL {
}
return
u
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type
TestProfileExpectation
struct
{
Profile
*
Profile
ExpectedJA3
string
// Expected JA3 hash (empty = don't check)
ExpectedJA4
string
// Expected full JA4 (empty = don't check)
JA4CipherHash
string
// Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func
TestAllProfiles
(
t
*
testing
.
T
)
{
skipNetworkTest
(
t
)
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles
:=
[]
TestProfileExpectation
{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile
:
&
Profile
{
Name
:
"linux_x64_node_v22171"
,
EnableGREASE
:
false
,
CipherSuites
:
[]
uint16
{
4866
,
4867
,
4865
,
49199
,
49195
,
49200
,
49196
,
158
,
49191
,
103
,
49192
,
107
,
163
,
159
,
52393
,
52392
,
52394
,
49327
,
49325
,
49315
,
49311
,
49245
,
49249
,
49239
,
49235
,
162
,
49326
,
49324
,
49314
,
49310
,
49244
,
49248
,
49238
,
49234
,
49188
,
106
,
49187
,
64
,
49162
,
49172
,
57
,
56
,
49161
,
49171
,
51
,
50
,
157
,
49313
,
49309
,
49233
,
156
,
49312
,
49308
,
49232
,
61
,
60
,
53
,
47
,
255
},
Curves
:
[]
uint16
{
29
,
23
,
30
,
25
,
24
,
256
,
257
,
258
,
259
,
260
},
PointFormats
:
[]
uint8
{
0
,
1
,
2
},
},
JA4CipherHash
:
"a33745022dd6"
,
// stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile
:
&
Profile
{
Name
:
"macos_arm64_node_v22180"
,
EnableGREASE
:
false
,
CipherSuites
:
[]
uint16
{
4866
,
4867
,
4865
,
49199
,
49195
,
49200
,
49196
,
158
,
49191
,
103
,
49192
,
107
,
163
,
159
,
52393
,
52392
,
52394
,
49327
,
49325
,
49315
,
49311
,
49245
,
49249
,
49239
,
49235
,
162
,
49326
,
49324
,
49314
,
49310
,
49244
,
49248
,
49238
,
49234
,
49188
,
106
,
49187
,
64
,
49162
,
49172
,
57
,
56
,
49161
,
49171
,
51
,
50
,
157
,
49313
,
49309
,
49233
,
156
,
49312
,
49308
,
49232
,
61
,
60
,
53
,
47
,
255
},
Curves
:
[]
uint16
{
29
,
23
,
30
,
25
,
24
,
256
,
257
,
258
,
259
,
260
},
PointFormats
:
[]
uint8
{
0
,
1
,
2
},
},
JA4CipherHash
:
"a33745022dd6"
,
// stable part (same cipher suites)
},
}
for
_
,
tc
:=
range
profiles
{
tc
:=
tc
// capture range variable
t
.
Run
(
tc
.
Profile
.
Name
,
func
(
t
*
testing
.
T
)
{
fp
:=
fetchFingerprint
(
t
,
tc
.
Profile
)
if
fp
==
nil
{
return
// fetchFingerprint already called t.Fatal
}
t
.
Logf
(
"Profile: %s"
,
tc
.
Profile
.
Name
)
t
.
Logf
(
" JA3: %s"
,
fp
.
JA3
)
t
.
Logf
(
" JA3 Hash: %s"
,
fp
.
JA3Hash
)
t
.
Logf
(
" JA4: %s"
,
fp
.
JA4
)
t
.
Logf
(
" PeetPrint: %s"
,
fp
.
PeetPrint
)
t
.
Logf
(
" PeetPrintHash: %s"
,
fp
.
PeetPrintHash
)
// Verify expectations
if
tc
.
ExpectedJA3
!=
""
{
if
fp
.
JA3Hash
==
tc
.
ExpectedJA3
{
t
.
Logf
(
" ✓ JA3 hash matches: %s"
,
tc
.
ExpectedJA3
)
}
else
{
t
.
Errorf
(
" ✗ JA3 hash mismatch: got %s, expected %s"
,
fp
.
JA3Hash
,
tc
.
ExpectedJA3
)
}
}
if
tc
.
ExpectedJA4
!=
""
{
if
fp
.
JA4
==
tc
.
ExpectedJA4
{
t
.
Logf
(
" ✓ JA4 matches: %s"
,
tc
.
ExpectedJA4
)
}
else
{
t
.
Errorf
(
" ✗ JA4 mismatch: got %s, expected %s"
,
fp
.
JA4
,
tc
.
ExpectedJA4
)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if
tc
.
JA4CipherHash
!=
""
{
if
strings
.
Contains
(
fp
.
JA4
,
"_"
+
tc
.
JA4CipherHash
+
"_"
)
{
t
.
Logf
(
" ✓ JA4 cipher hash matches: %s"
,
tc
.
JA4CipherHash
)
}
else
{
t
.
Errorf
(
" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s"
,
fp
.
JA4
,
tc
.
JA4CipherHash
)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func
fetchFingerprint
(
t
*
testing
.
T
,
profile
*
Profile
)
*
TLSInfo
{
t
.
Helper
()
dialer
:=
NewDialer
(
profile
,
nil
)
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://tls.peet.ws/api/all"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to create request: %v"
,
err
)
return
nil
}
req
.
Header
.
Set
(
"User-Agent"
,
"Claude Code/2.0.0 Node.js/20.0.0"
)
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to get fingerprint: %v"
,
err
)
return
nil
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to read response: %v"
,
err
)
return
nil
}
var
fpResp
FingerprintResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
fpResp
);
err
!=
nil
{
t
.
Logf
(
"Response body: %s"
,
string
(
body
))
t
.
Fatalf
(
"failed to parse fingerprint response: %v"
,
err
)
return
nil
}
return
&
fpResp
.
TLS
}
backend/internal/pkg/tlsfingerprint/test_types_test.go
0 → 100644
View file @
3d79773b
package
tlsfingerprint
// FingerprintResponse represents the response from tls.peet.ws/api/all.
// 共享测试类型,供 unit 和 integration 测试文件使用。
type
FingerprintResponse
struct
{
IP
string
`json:"ip"`
TLS
TLSInfo
`json:"tls"`
HTTP2
any
`json:"http2"`
}
// TLSInfo contains TLS fingerprint details.
type
TLSInfo
struct
{
JA3
string
`json:"ja3"`
JA3Hash
string
`json:"ja3_hash"`
JA4
string
`json:"ja4"`
PeetPrint
string
`json:"peetprint"`
PeetPrintHash
string
`json:"peetprint_hash"`
ClientRandom
string
`json:"client_random"`
SessionID
string
`json:"session_id"`
}
backend/internal/pkg/usagestats/usage_log_types.go
View file @
3d79773b
...
...
@@ -78,6 +78,16 @@ type ModelStat struct {
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
}
// GroupStat represents usage statistics for a single group
type
GroupStat
struct
{
GroupID
int64
`json:"group_id"`
GroupName
string
`json:"group_name"`
Requests
int64
`json:"requests"`
TotalTokens
int64
`json:"total_tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type
UserUsageTrendPoint
struct
{
Date
string
`json:"date"`
...
...
@@ -139,10 +149,13 @@ type UsageLogFilters struct {
AccountID
int64
GroupID
int64
Model
string
RequestType
*
int16
Stream
*
bool
BillingType
*
int8
StartTime
*
time
.
Time
EndTime
*
time
.
Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
ExactTotal
bool
}
// UsageStats represents usage statistics
...
...
backend/internal/repository/account_repo.go
View file @
3d79773b
...
...
@@ -15,7 +15,6 @@ import (
"database/sql"
"encoding/json"
"errors"
"log"
"strconv"
"time"
...
...
@@ -25,6 +24,7 @@ import (
dbgroup
"github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
...
...
@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache
service
.
SchedulerCache
}
type
tempUnschedSnapshot
struct
{
until
*
time
.
Time
reason
string
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func
NewAccountRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
,
schedulerCache
service
.
SchedulerCache
)
service
.
AccountRepository
{
...
...
@@ -127,7 +122,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account
.
CreatedAt
=
created
.
CreatedAt
account
.
UpdatedAt
=
created
.
UpdatedAt
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
account
.
ID
,
nil
,
buildSchedulerGroupPayload
(
account
.
GroupIDs
));
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue account create failed: account=%d err=%v"
,
account
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue account create failed: account=%d err=%v"
,
account
.
ID
,
err
)
}
return
nil
}
...
...
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs
=
append
(
accountIDs
,
acc
.
ID
)
}
tempUnschedMap
,
err
:=
r
.
loadTempUnschedStates
(
ctx
,
accountIDs
)
if
err
!=
nil
{
return
nil
,
err
}
groupsByAccount
,
groupIDsByAccount
,
accountGroupsByAccount
,
err
:=
r
.
loadAccountGroups
(
ctx
,
accountIDs
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if
ags
,
ok
:=
accountGroupsByAccount
[
entAcc
.
ID
];
ok
{
out
.
AccountGroups
=
ags
}
if
snap
,
ok
:=
tempUnschedMap
[
entAcc
.
ID
];
ok
{
out
.
TempUnschedulableUntil
=
snap
.
until
out
.
TempUnschedulableReason
=
snap
.
reason
}
outByID
[
entAcc
.
ID
]
=
out
}
...
...
@@ -388,7 +374,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
account
.
UpdatedAt
=
updated
.
UpdatedAt
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
account
.
ID
,
nil
,
buildSchedulerGroupPayload
(
account
.
GroupIDs
));
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue account update failed: account=%d err=%v"
,
account
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue account update failed: account=%d err=%v"
,
account
.
ID
,
err
)
}
if
account
.
Status
==
service
.
StatusError
||
account
.
Status
==
service
.
StatusDisabled
||
!
account
.
Schedulable
{
r
.
syncSchedulerAccountSnapshot
(
ctx
,
account
.
ID
)
...
...
@@ -429,7 +415,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
buildSchedulerGroupPayload
(
groupIDs
));
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue account delete failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue account delete failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -541,7 +527,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
},
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountLastUsed
,
&
id
,
nil
,
payload
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue last used failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue last used failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -576,7 +562,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
}
payload
:=
map
[
string
]
any
{
"last_used"
:
lastUsedPayload
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountLastUsed
,
nil
,
nil
,
payload
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue batch last used failed: err=%v"
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue batch last used failed: err=%v"
,
err
)
}
return
nil
}
...
...
@@ -591,7 +577,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
return
err
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue set error failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue set error failed: account=%d err=%v"
,
id
,
err
)
}
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
return
nil
...
...
@@ -611,11 +597,48 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
account
,
err
:=
r
.
GetByID
(
ctx
,
accountID
)
if
err
!=
nil
{
log
.
Printf
(
"[Scheduler] sync account snapshot read failed: id=%d err=%v"
,
accountID
,
err
)
logger
.
LegacyPrintf
(
"repository.account"
,
"[Scheduler] sync account snapshot read failed: id=%d err=%v"
,
accountID
,
err
)
return
}
if
err
:=
r
.
schedulerCache
.
SetAccount
(
ctx
,
account
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.account"
,
"[Scheduler] sync account snapshot write failed: id=%d err=%v"
,
accountID
,
err
)
}
}
func
(
r
*
accountRepository
)
syncSchedulerAccountSnapshots
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
{
if
r
==
nil
||
r
.
schedulerCache
==
nil
||
len
(
accountIDs
)
==
0
{
return
}
uniqueIDs
:=
make
([]
int64
,
0
,
len
(
accountIDs
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
accountIDs
))
for
_
,
id
:=
range
accountIDs
{
if
id
<=
0
{
continue
}
if
_
,
exists
:=
seen
[
id
];
exists
{
continue
}
seen
[
id
]
=
struct
{}{}
uniqueIDs
=
append
(
uniqueIDs
,
id
)
}
if
len
(
uniqueIDs
)
==
0
{
return
}
accounts
,
err
:=
r
.
GetByIDs
(
ctx
,
uniqueIDs
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.account"
,
"[Scheduler] batch sync account snapshot read failed: count=%d err=%v"
,
len
(
uniqueIDs
),
err
)
return
}
for
_
,
account
:=
range
accounts
{
if
account
==
nil
{
continue
}
if
err
:=
r
.
schedulerCache
.
SetAccount
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[Scheduler] sync account snapshot write failed: id=%d err=%v"
,
accountID
,
err
)
logger
.
LegacyPrintf
(
"repository.account"
,
"[Scheduler] batch sync account snapshot write failed: id=%d err=%v"
,
account
.
ID
,
err
)
}
}
}
...
...
@@ -649,7 +672,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
}
payload
:=
buildSchedulerGroupPayload
([]
int64
{
groupID
})
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountGroupsChanged
,
&
accountID
,
nil
,
payload
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v"
,
accountID
,
groupID
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v"
,
accountID
,
groupID
,
err
)
}
return
nil
}
...
...
@@ -666,7 +689,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
}
payload
:=
buildSchedulerGroupPayload
([]
int64
{
groupID
})
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountGroupsChanged
,
&
accountID
,
nil
,
payload
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v"
,
accountID
,
groupID
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v"
,
accountID
,
groupID
,
err
)
}
return
nil
}
...
...
@@ -739,7 +762,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
payload
:=
buildSchedulerGroupPayload
(
mergeGroupIDs
(
existingGroupIDs
,
groupIDs
))
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountGroupsChanged
,
&
accountID
,
nil
,
payload
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v"
,
accountID
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v"
,
accountID
,
err
)
}
return
nil
}
...
...
@@ -824,6 +847,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
return
r
.
accountsToService
(
ctx
,
accounts
)
}
func
(
r
*
accountRepository
)
ListSchedulableUngroupedByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
now
:=
time
.
Now
()
accounts
,
err
:=
r
.
client
.
Account
.
Query
()
.
Where
(
dbaccount
.
PlatformEQ
(
platform
),
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbaccount
.
SchedulableEQ
(
true
),
dbaccount
.
Not
(
dbaccount
.
HasAccountGroups
()),
tempUnschedulablePredicate
(),
notExpiredPredicate
(
now
),
dbaccount
.
Or
(
dbaccount
.
OverloadUntilIsNil
(),
dbaccount
.
OverloadUntilLTE
(
now
)),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
now
)),
)
.
Order
(
dbent
.
Asc
(
dbaccount
.
FieldPriority
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
return
r
.
accountsToService
(
ctx
,
accounts
)
}
func
(
r
*
accountRepository
)
ListSchedulableUngroupedByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
platforms
)
==
0
{
return
nil
,
nil
}
now
:=
time
.
Now
()
accounts
,
err
:=
r
.
client
.
Account
.
Query
()
.
Where
(
dbaccount
.
PlatformIn
(
platforms
...
),
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbaccount
.
SchedulableEQ
(
true
),
dbaccount
.
Not
(
dbaccount
.
HasAccountGroups
()),
tempUnschedulablePredicate
(),
notExpiredPredicate
(
now
),
dbaccount
.
Or
(
dbaccount
.
OverloadUntilIsNil
(),
dbaccount
.
OverloadUntilLTE
(
now
)),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
now
)),
)
.
Order
(
dbent
.
Asc
(
dbaccount
.
FieldPriority
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
return
r
.
accountsToService
(
ctx
,
accounts
)
}
func
(
r
*
accountRepository
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
platforms
)
==
0
{
return
nil
,
nil
...
...
@@ -847,7 +915,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return
err
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -894,7 +962,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return
service
.
ErrAccountNotFound
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -908,7 +976,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
return
err
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue overload failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue overload failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -927,7 +995,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
return
err
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v"
,
id
,
err
)
}
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
return
nil
...
...
@@ -946,7 +1014,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
return
err
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -962,7 +1030,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return
err
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -986,7 +1054,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
return
service
.
ErrAccountNotFound
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -1010,7 +1078,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
return
service
.
ErrAccountNotFound
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -1032,7 +1100,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
// 触发调度器缓存更新(仅当窗口时间有变化时)
if
start
!=
nil
||
end
!=
nil
{
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue session window update failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue session window update failed: account=%d err=%v"
,
id
,
err
)
}
}
return
nil
...
...
@@ -1047,7 +1115,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return
err
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v"
,
id
,
err
)
}
if
!
schedulable
{
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
...
...
@@ -1075,7 +1143,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
}
if
rows
>
0
{
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventFullRebuild
,
nil
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v"
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v"
,
err
)
}
}
return
rows
,
nil
...
...
@@ -1111,7 +1179,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return
service
.
ErrAccountNotFound
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue extra update failed: account=%d err=%v"
,
id
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue extra update failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
...
...
@@ -1205,7 +1273,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if
rows
>
0
{
payload
:=
map
[
string
]
any
{
"account_ids"
:
ids
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountBulkChanged
,
nil
,
nil
,
payload
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue bulk update failed: err=%v"
,
err
)
log
ger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue bulk update failed: err=%v"
,
err
)
}
shouldSync
:=
false
if
updates
.
Status
!=
nil
&&
(
*
updates
.
Status
==
service
.
StatusError
||
*
updates
.
Status
==
service
.
StatusDisabled
)
{
...
...
@@ -1215,9 +1283,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync
=
true
}
if
shouldSync
{
for
_
,
id
:=
range
ids
{
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
}
r
.
syncSchedulerAccountSnapshots
(
ctx
,
ids
)
}
}
return
rows
,
nil
...
...
@@ -1309,10 +1375,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if
err
!=
nil
{
return
nil
,
err
}
tempUnschedMap
,
err
:=
r
.
loadTempUnschedStates
(
ctx
,
accountIDs
)
if
err
!=
nil
{
return
nil
,
err
}
groupsByAccount
,
groupIDsByAccount
,
accountGroupsByAccount
,
err
:=
r
.
loadAccountGroups
(
ctx
,
accountIDs
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -1338,10 +1400,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if
ags
,
ok
:=
accountGroupsByAccount
[
acc
.
ID
];
ok
{
out
.
AccountGroups
=
ags
}
if
snap
,
ok
:=
tempUnschedMap
[
acc
.
ID
];
ok
{
out
.
TempUnschedulableUntil
=
snap
.
until
out
.
TempUnschedulableReason
=
snap
.
reason
}
outAccounts
=
append
(
outAccounts
,
*
out
)
}
...
...
@@ -1366,48 +1424,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
)
}
func
(
r
*
accountRepository
)
loadTempUnschedStates
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
tempUnschedSnapshot
,
error
)
{
out
:=
make
(
map
[
int64
]
tempUnschedSnapshot
)
if
len
(
accountIDs
)
==
0
{
return
out
,
nil
}
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`
,
pq
.
Array
(
accountIDs
))
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
id
int64
var
until
sql
.
NullTime
var
reason
sql
.
NullString
if
err
:=
rows
.
Scan
(
&
id
,
&
until
,
&
reason
);
err
!=
nil
{
return
nil
,
err
}
var
untilPtr
*
time
.
Time
if
until
.
Valid
{
tmp
:=
until
.
Time
untilPtr
=
&
tmp
}
if
reason
.
Valid
{
out
[
id
]
=
tempUnschedSnapshot
{
until
:
untilPtr
,
reason
:
reason
.
String
}
}
else
{
out
[
id
]
=
tempUnschedSnapshot
{
until
:
untilPtr
,
reason
:
""
}
}
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
func
(
r
*
accountRepository
)
loadProxies
(
ctx
context
.
Context
,
proxyIDs
[]
int64
)
(
map
[
int64
]
*
service
.
Proxy
,
error
)
{
proxyMap
:=
make
(
map
[
int64
]
*
service
.
Proxy
)
if
len
(
proxyIDs
)
==
0
{
...
...
@@ -1540,6 +1556,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
RateLimitedAt
:
m
.
RateLimitedAt
,
RateLimitResetAt
:
m
.
RateLimitResetAt
,
OverloadUntil
:
m
.
OverloadUntil
,
TempUnschedulableUntil
:
m
.
TempUnschedulableUntil
,
TempUnschedulableReason
:
derefString
(
m
.
TempUnschedulableReason
),
SessionWindowStart
:
m
.
SessionWindowStart
,
SessionWindowEnd
:
m
.
SessionWindowEnd
,
SessionWindowStatus
:
derefString
(
m
.
SessionWindowStatus
),
...
...
@@ -1578,3 +1596,64 @@ func joinClauses(clauses []string, sep string) string {
func
itoa
(
v
int
)
string
{
return
strconv
.
Itoa
(
v
)
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func
(
r
*
accountRepository
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
service
.
Account
,
error
)
{
accounts
,
err
:=
r
.
client
.
Account
.
Query
()
.
Where
(
dbaccount
.
PlatformEQ
(
"sora"
),
// 限定平台为 sora
dbaccount
.
DeletedAtIsNil
(),
func
(
s
*
entsql
.
Selector
)
{
path
:=
sqljson
.
Path
(
key
)
switch
v
:=
value
.
(
type
)
{
case
string
:
preds
:=
[]
*
entsql
.
Predicate
{
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
v
,
path
)}
if
parsed
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
);
err
==
nil
{
preds
=
append
(
preds
,
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
parsed
,
path
))
}
if
len
(
preds
)
==
1
{
s
.
Where
(
preds
[
0
])
}
else
{
s
.
Where
(
entsql
.
Or
(
preds
...
))
}
case
int
:
s
.
Where
(
entsql
.
Or
(
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
v
,
path
),
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
strconv
.
Itoa
(
v
),
path
),
))
case
int64
:
s
.
Where
(
entsql
.
Or
(
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
v
,
path
),
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
strconv
.
FormatInt
(
v
,
10
),
path
),
))
case
json
.
Number
:
if
parsed
,
err
:=
v
.
Int64
();
err
==
nil
{
s
.
Where
(
entsql
.
Or
(
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
parsed
,
path
),
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
v
.
String
(),
path
),
))
}
else
{
s
.
Where
(
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
v
.
String
(),
path
))
}
default
:
s
.
Where
(
sqljson
.
ValueEQ
(
dbaccount
.
FieldExtra
,
value
,
path
))
}
},
)
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrAccountNotFound
,
nil
)
}
return
r
.
accountsToService
(
ctx
,
accounts
)
}
backend/internal/repository/account_repo_integration_test.go
View file @
3d79773b
...
...
@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s
.
Require
()
.
Nil
(
got
.
OverloadUntil
)
}
func
(
s
*
AccountRepoSuite
)
TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs
()
{
acc1
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-temp-1"
})
acc2
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-temp-2"
})
until
:=
time
.
Now
()
.
Add
(
15
*
time
.
Minute
)
.
UTC
()
.
Truncate
(
time
.
Second
)
reason
:=
`{"rule":"429","matched_keyword":"too many requests"}`
s
.
Require
()
.
NoError
(
s
.
repo
.
SetTempUnschedulable
(
s
.
ctx
,
acc1
.
ID
,
until
,
reason
))
gotByID
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
acc1
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NotNil
(
gotByID
.
TempUnschedulableUntil
)
s
.
Require
()
.
WithinDuration
(
until
,
*
gotByID
.
TempUnschedulableUntil
,
time
.
Second
)
s
.
Require
()
.
Equal
(
reason
,
gotByID
.
TempUnschedulableReason
)
gotByIDs
,
err
:=
s
.
repo
.
GetByIDs
(
s
.
ctx
,
[]
int64
{
acc2
.
ID
,
acc1
.
ID
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
gotByIDs
,
2
)
s
.
Require
()
.
Equal
(
acc2
.
ID
,
gotByIDs
[
0
]
.
ID
)
s
.
Require
()
.
Nil
(
gotByIDs
[
0
]
.
TempUnschedulableUntil
)
s
.
Require
()
.
Equal
(
""
,
gotByIDs
[
0
]
.
TempUnschedulableReason
)
s
.
Require
()
.
Equal
(
acc1
.
ID
,
gotByIDs
[
1
]
.
ID
)
s
.
Require
()
.
NotNil
(
gotByIDs
[
1
]
.
TempUnschedulableUntil
)
s
.
Require
()
.
WithinDuration
(
until
,
*
gotByIDs
[
1
]
.
TempUnschedulableUntil
,
time
.
Second
)
s
.
Require
()
.
Equal
(
reason
,
gotByIDs
[
1
]
.
TempUnschedulableReason
)
s
.
Require
()
.
NoError
(
s
.
repo
.
ClearTempUnschedulable
(
s
.
ctx
,
acc1
.
ID
))
cleared
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
acc1
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Nil
(
cleared
.
TempUnschedulableUntil
)
s
.
Require
()
.
Equal
(
""
,
cleared
.
TempUnschedulableReason
)
}
// --- UpdateLastUsed ---
func
(
s
*
AccountRepoSuite
)
TestUpdateLastUsed
()
{
...
...
backend/internal/repository/allowed_groups_contract_integration_test.go
View file @
3d79773b
...
...
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo
:=
newUserRepositoryWithSQL
(
entClient
,
tx
)
groupRepo
:=
newGroupRepositoryWithSQL
(
entClient
,
tx
)
apiKeyRepo
:=
N
ewAPIKeyRepository
(
entClient
)
apiKeyRepo
:=
n
ewAPIKeyRepository
WithSQL
(
entClient
,
tx
)
u
:=
&
service
.
User
{
Email
:
uniqueTestValue
(
t
,
"cascade-user"
)
+
"@example.com"
,
...
...
backend/internal/repository/api_key_repo.go
View file @
3d79773b
...
...
@@ -2,6 +2,7 @@ package repository
import
(
"context"
"database/sql"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
...
...
@@ -16,10 +17,15 @@ import (
type
apiKeyRepository
struct
{
client
*
dbent
.
Client
sql
sqlExecutor
}
func
NewAPIKeyRepository
(
client
*
dbent
.
Client
)
service
.
APIKeyRepository
{
return
&
apiKeyRepository
{
client
:
client
}
func
NewAPIKeyRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
)
service
.
APIKeyRepository
{
return
newAPIKeyRepositoryWithSQL
(
client
,
sqlDB
)
}
func
newAPIKeyRepositoryWithSQL
(
client
*
dbent
.
Client
,
sqlq
sqlExecutor
)
*
apiKeyRepository
{
return
&
apiKeyRepository
{
client
:
client
,
sql
:
sqlq
}
}
func
(
r
*
apiKeyRepository
)
activeQuery
()
*
dbent
.
APIKeyQuery
{
...
...
@@ -34,9 +40,13 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetName
(
key
.
Name
)
.
SetStatus
(
key
.
Status
)
.
SetNillableGroupID
(
key
.
GroupID
)
.
SetNillableLastUsedAt
(
key
.
LastUsedAt
)
.
SetQuota
(
key
.
Quota
)
.
SetQuotaUsed
(
key
.
QuotaUsed
)
.
SetNillableExpiresAt
(
key
.
ExpiresAt
)
SetNillableExpiresAt
(
key
.
ExpiresAt
)
.
SetRateLimit5h
(
key
.
RateLimit5h
)
.
SetRateLimit1d
(
key
.
RateLimit1d
)
.
SetRateLimit7d
(
key
.
RateLimit7d
)
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
...
...
@@ -48,6 +58,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
==
nil
{
key
.
ID
=
created
.
ID
key
.
LastUsedAt
=
created
.
LastUsedAt
key
.
CreatedAt
=
created
.
CreatedAt
key
.
UpdatedAt
=
created
.
UpdatedAt
}
...
...
@@ -116,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey
.
FieldQuota
,
apikey
.
FieldQuotaUsed
,
apikey
.
FieldExpiresAt
,
apikey
.
FieldRateLimit5h
,
apikey
.
FieldRateLimit1d
,
apikey
.
FieldRateLimit7d
,
)
.
WithUser
(
func
(
q
*
dbent
.
UserQuery
)
{
q
.
Select
(
...
...
@@ -140,6 +154,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldImagePrice1k
,
group
.
FieldImagePrice2k
,
group
.
FieldImagePrice4k
,
group
.
FieldSoraImagePrice360
,
group
.
FieldSoraImagePrice540
,
group
.
FieldSoraVideoPricePerRequest
,
group
.
FieldSoraVideoPricePerRequestHd
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldFallbackGroupID
,
group
.
FieldFallbackGroupIDOnInvalidRequest
,
...
...
@@ -165,13 +183,20 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
client
:=
clientFromContext
(
ctx
,
r
.
client
)
now
:=
time
.
Now
()
builder
:=
r
.
client
.
APIKey
.
Update
()
.
builder
:=
client
.
APIKey
.
Update
()
.
Where
(
apikey
.
IDEQ
(
key
.
ID
),
apikey
.
DeletedAtIsNil
())
.
SetName
(
key
.
Name
)
.
SetStatus
(
key
.
Status
)
.
SetQuota
(
key
.
Quota
)
.
SetQuotaUsed
(
key
.
QuotaUsed
)
.
SetRateLimit5h
(
key
.
RateLimit5h
)
.
SetRateLimit1d
(
key
.
RateLimit1d
)
.
SetRateLimit7d
(
key
.
RateLimit7d
)
.
SetUsage5h
(
key
.
Usage5h
)
.
SetUsage1d
(
key
.
Usage1d
)
.
SetUsage7d
(
key
.
Usage7d
)
.
SetUpdatedAt
(
now
)
if
key
.
GroupID
!=
nil
{
builder
.
SetGroupID
(
*
key
.
GroupID
)
...
...
@@ -186,6 +211,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder
.
ClearExpiresAt
()
}
// Rate limit window start times
if
key
.
Window5hStart
!=
nil
{
builder
.
SetWindow5hStart
(
*
key
.
Window5hStart
)
}
else
{
builder
.
ClearWindow5hStart
()
}
if
key
.
Window1dStart
!=
nil
{
builder
.
SetWindow1dStart
(
*
key
.
Window1dStart
)
}
else
{
builder
.
ClearWindow1dStart
()
}
if
key
.
Window7dStart
!=
nil
{
builder
.
SetWindow7dStart
(
*
key
.
Window7dStart
)
}
else
{
builder
.
ClearWindow7dStart
()
}
// IP 限制字段
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
...
...
@@ -239,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return
nil
}
func
(
r
*
apiKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
apiKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
,
filters
service
.
APIKeyListFilters
)
([]
service
.
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
activeQuery
()
.
Where
(
apikey
.
UserIDEQ
(
userID
))
// Apply filters
if
filters
.
Search
!=
""
{
q
=
q
.
Where
(
apikey
.
Or
(
apikey
.
NameContainsFold
(
filters
.
Search
),
apikey
.
KeyContainsFold
(
filters
.
Search
),
))
}
if
filters
.
Status
!=
""
{
q
=
q
.
Where
(
apikey
.
StatusEQ
(
filters
.
Status
))
}
if
filters
.
GroupID
!=
nil
{
if
*
filters
.
GroupID
==
0
{
q
=
q
.
Where
(
apikey
.
GroupIDIsNil
())
}
else
{
q
=
q
.
Where
(
apikey
.
GroupIDEQ
(
*
filters
.
GroupID
))
}
}
total
,
err
:=
q
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
...
...
@@ -375,36 +435,92 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return
keys
,
nil
}
// IncrementQuotaUsed
atomically increments the quota_used field and returns the new value
// IncrementQuotaUsed
使用 Ent 原子递增 quota_used 字段并返回新值
func
(
r
*
apiKeyRepository
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m
,
err
:=
r
.
activeQuery
()
.
Where
(
apikey
.
IDEQ
(
id
))
.
Select
(
apikey
.
FieldQuotaUsed
)
.
Only
(
ctx
)
updated
,
err
:=
r
.
client
.
APIKey
.
UpdateOneID
(
id
)
.
Where
(
apikey
.
DeletedAtIsNil
())
.
AddQuotaUsed
(
amount
)
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
0
,
service
.
ErrAPIKeyNotFound
}
return
0
,
err
}
return
updated
.
QuotaUsed
,
nil
}
newValue
:=
m
.
QuotaUsed
+
amount
// Update with new value
func
(
r
*
apiKeyRepository
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
affected
,
err
:=
r
.
client
.
APIKey
.
Update
()
.
Where
(
apikey
.
IDEQ
(
id
),
apikey
.
DeletedAtIsNil
())
.
SetQuotaUsed
(
newValue
)
.
SetLastUsedAt
(
usedAt
)
.
SetUpdatedAt
(
usedAt
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
0
,
err
return
err
}
if
affected
==
0
{
return
0
,
service
.
ErrAPIKeyNotFound
return
service
.
ErrAPIKeyNotFound
}
return
nil
}
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
// window start times via COALESCE if not already set.
func
(
r
*
apiKeyRepository
)
IncrementRateLimitUsage
(
ctx
context
.
Context
,
id
int64
,
cost
float64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE api_keys SET
usage_5h = usage_5h + $1,
usage_1d = usage_1d + $1,
usage_7d = usage_7d + $1,
window_5h_start = COALESCE(window_5h_start, NOW()),
window_1d_start = COALESCE(window_1d_start, NOW()),
window_7d_start = COALESCE(window_7d_start, NOW()),
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL`
,
cost
,
id
)
return
err
}
// ResetRateLimitWindows resets expired rate limit windows atomically.
func
(
r
*
apiKeyRepository
)
ResetRateLimitWindows
(
ctx
context
.
Context
,
id
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`
,
id
)
return
err
}
return
newValue
,
nil
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
func
(
r
*
apiKeyRepository
)
GetRateLimitData
(
ctx
context
.
Context
,
id
int64
)
(
result
*
service
.
APIKeyRateLimitData
,
err
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
FROM api_keys
WHERE id = $1 AND deleted_at IS NULL`
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
if
closeErr
:=
rows
.
Close
();
closeErr
!=
nil
&&
err
==
nil
{
err
=
closeErr
}
}()
if
!
rows
.
Next
()
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
data
:=
&
service
.
APIKeyRateLimitData
{}
if
err
:=
rows
.
Scan
(
&
data
.
Usage5h
,
&
data
.
Usage1d
,
&
data
.
Usage7d
,
&
data
.
Window5hStart
,
&
data
.
Window1dStart
,
&
data
.
Window7dStart
);
err
!=
nil
{
return
nil
,
err
}
return
data
,
rows
.
Err
()
}
func
apiKeyEntityToService
(
m
*
dbent
.
APIKey
)
*
service
.
APIKey
{
...
...
@@ -419,12 +535,22 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Status
:
m
.
Status
,
IPWhitelist
:
m
.
IPWhitelist
,
IPBlacklist
:
m
.
IPBlacklist
,
LastUsedAt
:
m
.
LastUsedAt
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
GroupID
:
m
.
GroupID
,
Quota
:
m
.
Quota
,
QuotaUsed
:
m
.
QuotaUsed
,
ExpiresAt
:
m
.
ExpiresAt
,
RateLimit5h
:
m
.
RateLimit5h
,
RateLimit1d
:
m
.
RateLimit1d
,
RateLimit7d
:
m
.
RateLimit7d
,
Usage5h
:
m
.
Usage5h
,
Usage1d
:
m
.
Usage1d
,
Usage7d
:
m
.
Usage7d
,
Window5hStart
:
m
.
Window5hStart
,
Window1dStart
:
m
.
Window1dStart
,
Window7dStart
:
m
.
Window7dStart
,
}
if
m
.
Edges
.
User
!=
nil
{
out
.
User
=
userEntityToService
(
m
.
Edges
.
User
)
...
...
@@ -449,6 +575,8 @@ func userEntityToService(u *dbent.User) *service.User {
Balance
:
u
.
Balance
,
Concurrency
:
u
.
Concurrency
,
Status
:
u
.
Status
,
SoraStorageQuotaBytes
:
u
.
SoraStorageQuotaBytes
,
SoraStorageUsedBytes
:
u
.
SoraStorageUsedBytes
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
TotpEnabled
:
u
.
TotpEnabled
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
...
...
@@ -477,6 +605,11 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K
:
g
.
ImagePrice1k
,
ImagePrice2K
:
g
.
ImagePrice2k
,
ImagePrice4K
:
g
.
ImagePrice4k
,
SoraImagePrice360
:
g
.
SoraImagePrice360
,
SoraImagePrice540
:
g
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
g
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
g
.
SoraVideoPricePerRequestHd
,
SoraStorageQuotaBytes
:
g
.
SoraStorageQuotaBytes
,
DefaultValidityDays
:
g
.
DefaultValidityDays
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
...
...
backend/internal/repository/api_key_repo_integration_test.go
View file @
3d79773b
...
...
@@ -4,11 +4,14 @@ package repository
import
(
"context"
"sync"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
...
...
@@ -23,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() {
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
s
.
client
=
tx
.
Client
()
s
.
repo
=
N
ewAPIKeyRepository
(
s
.
client
)
.
(
*
apiKeyRepository
)
s
.
repo
=
n
ewAPIKeyRepository
WithSQL
(
s
.
client
,
tx
)
}
func
TestAPIKeyRepoSuite
(
t
*
testing
.
T
)
{
...
...
@@ -155,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-list-1"
,
"Key 1"
,
nil
)
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-list-2"
,
"Key 2"
,
nil
)
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
}
,
service
.
APIKeyListFilters
{}
)
s
.
Require
()
.
NoError
(
err
,
"ListByUserID"
)
s
.
Require
()
.
Len
(
keys
,
2
)
s
.
Require
()
.
Equal
(
int64
(
2
),
page
.
Total
)
...
...
@@ -167,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-page-"
+
string
(
rune
(
'a'
+
i
)),
"Key"
,
nil
)
}
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
2
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
2
}
,
service
.
APIKeyListFilters
{}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
keys
,
2
)
s
.
Require
()
.
Equal
(
int64
(
5
),
page
.
Total
)
...
...
@@ -311,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
got2
.
Status
)
s
.
Require
()
.
Nil
(
got2
.
GroupID
)
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
}
,
service
.
APIKeyListFilters
{}
)
s
.
Require
()
.
NoError
(
err
,
"ListByUserID"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
keys
,
1
)
...
...
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
k
),
"create api key"
)
return
k
}
// --- IncrementQuotaUsed ---
func
(
s
*
APIKeyRepoSuite
)
TestIncrementQuotaUsed_Basic
()
{
user
:=
s
.
mustCreateUser
(
"incr-basic@test.com"
)
key
:=
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-incr-basic"
,
"Incr"
,
nil
)
newQuota
,
err
:=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
key
.
ID
,
1.5
)
s
.
Require
()
.
NoError
(
err
,
"IncrementQuotaUsed"
)
s
.
Require
()
.
Equal
(
1.5
,
newQuota
,
"第一次递增后应为 1.5"
)
newQuota
,
err
=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
key
.
ID
,
2.5
)
s
.
Require
()
.
NoError
(
err
,
"IncrementQuotaUsed second"
)
s
.
Require
()
.
Equal
(
4.0
,
newQuota
,
"第二次递增后应为 4.0"
)
}
func
(
s
*
APIKeyRepoSuite
)
TestIncrementQuotaUsed_NotFound
()
{
_
,
err
:=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
999999
,
1.0
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrAPIKeyNotFound
,
"不存在的 key 应返回 ErrAPIKeyNotFound"
)
}
func
(
s
*
APIKeyRepoSuite
)
TestIncrementQuotaUsed_DeletedKey
()
{
user
:=
s
.
mustCreateUser
(
"incr-deleted@test.com"
)
key
:=
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-incr-del"
,
"Deleted"
,
nil
)
s
.
Require
()
.
NoError
(
s
.
repo
.
Delete
(
s
.
ctx
,
key
.
ID
),
"Delete"
)
_
,
err
:=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
key
.
ID
,
1.0
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrAPIKeyNotFound
,
"已删除的 key 应返回 ErrAPIKeyNotFound"
)
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
func
TestIncrementQuotaUsed_Concurrent
(
t
*
testing
.
T
)
{
client
:=
testEntClient
(
t
)
repo
:=
NewAPIKeyRepository
(
client
,
integrationDB
)
.
(
*
apiKeyRepository
)
ctx
:=
context
.
Background
()
// 创建测试用户和 API Key
u
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"concurrent-incr-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
)
+
"@test.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetStatus
(
service
.
StatusActive
)
.
SetRole
(
service
.
RoleUser
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
,
"create user"
)
k
:=
&
service
.
APIKey
{
UserID
:
u
.
ID
,
Key
:
"sk-concurrent-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Name
:
"Concurrent"
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
k
),
"create api key"
)
t
.
Cleanup
(
func
()
{
_
=
client
.
APIKey
.
DeleteOneID
(
k
.
ID
)
.
Exec
(
ctx
)
_
=
client
.
User
.
DeleteOneID
(
u
.
ID
)
.
Exec
(
ctx
)
})
// 10 个 goroutine 各递增 1.0,总计应为 10.0
const
goroutines
=
10
const
increment
=
1.0
var
wg
sync
.
WaitGroup
errs
:=
make
([]
error
,
goroutines
)
for
i
:=
0
;
i
<
goroutines
;
i
++
{
wg
.
Add
(
1
)
go
func
(
idx
int
)
{
defer
wg
.
Done
()
_
,
errs
[
idx
]
=
repo
.
IncrementQuotaUsed
(
ctx
,
k
.
ID
,
increment
)
}(
i
)
}
wg
.
Wait
()
for
i
,
e
:=
range
errs
{
require
.
NoError
(
t
,
e
,
"goroutine %d failed"
,
i
)
}
// 验证最终结果
got
,
err
:=
repo
.
GetByID
(
ctx
,
k
.
ID
)
require
.
NoError
(
t
,
err
,
"GetByID"
)
require
.
Equal
(
t
,
float64
(
goroutines
)
*
increment
,
got
.
QuotaUsed
,
"并发递增后总和应为 %v,实际为 %v"
,
float64
(
goroutines
)
*
increment
,
got
.
QuotaUsed
)
}
backend/internal/repository/api_key_repo_last_used_unit_test.go
0 → 100644
View file @
3d79773b
package
repository
import
(
"context"
"database/sql"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql
"entgo.io/ent/dialect/sql"
_
"modernc.org/sqlite"
)
func
newAPIKeyRepoSQLite
(
t
*
testing
.
T
)
(
*
apiKeyRepository
,
*
dbent
.
Client
)
{
t
.
Helper
()
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:api_key_repo_last_used?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
return
&
apiKeyRepository
{
client
:
client
},
client
}
func
mustCreateAPIKeyRepoUser
(
t
*
testing
.
T
,
ctx
context
.
Context
,
client
*
dbent
.
Client
,
email
string
)
*
service
.
User
{
t
.
Helper
()
u
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
email
)
.
SetPasswordHash
(
"test-password-hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
return
userEntityToService
(
u
)
}
func
TestAPIKeyRepository_CreateWithLastUsedAt
(
t
*
testing
.
T
)
{
repo
,
client
:=
newAPIKeyRepoSQLite
(
t
)
ctx
:=
context
.
Background
()
user
:=
mustCreateAPIKeyRepoUser
(
t
,
ctx
,
client
,
"create-last-used@test.com"
)
lastUsed
:=
time
.
Now
()
.
UTC
()
.
Add
(
-
time
.
Hour
)
.
Truncate
(
time
.
Second
)
key
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-create-last-used"
,
Name
:
"CreateWithLastUsed"
,
Status
:
service
.
StatusActive
,
LastUsedAt
:
&
lastUsed
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
key
))
require
.
NotNil
(
t
,
key
.
LastUsedAt
)
require
.
WithinDuration
(
t
,
lastUsed
,
*
key
.
LastUsedAt
,
time
.
Second
)
got
,
err
:=
repo
.
GetByID
(
ctx
,
key
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
got
.
LastUsedAt
)
require
.
WithinDuration
(
t
,
lastUsed
,
*
got
.
LastUsedAt
,
time
.
Second
)
}
func
TestAPIKeyRepository_UpdateLastUsed
(
t
*
testing
.
T
)
{
repo
,
client
:=
newAPIKeyRepoSQLite
(
t
)
ctx
:=
context
.
Background
()
user
:=
mustCreateAPIKeyRepoUser
(
t
,
ctx
,
client
,
"update-last-used@test.com"
)
key
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-update-last-used"
,
Name
:
"UpdateLastUsed"
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
key
))
before
,
err
:=
repo
.
GetByID
(
ctx
,
key
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
before
.
LastUsedAt
)
target
:=
time
.
Now
()
.
UTC
()
.
Add
(
2
*
time
.
Minute
)
.
Truncate
(
time
.
Second
)
require
.
NoError
(
t
,
repo
.
UpdateLastUsed
(
ctx
,
key
.
ID
,
target
))
after
,
err
:=
repo
.
GetByID
(
ctx
,
key
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
after
.
LastUsedAt
)
require
.
WithinDuration
(
t
,
target
,
*
after
.
LastUsedAt
,
time
.
Second
)
require
.
WithinDuration
(
t
,
target
,
after
.
UpdatedAt
,
time
.
Second
)
}
func
TestAPIKeyRepository_UpdateLastUsedDeletedKey
(
t
*
testing
.
T
)
{
repo
,
client
:=
newAPIKeyRepoSQLite
(
t
)
ctx
:=
context
.
Background
()
user
:=
mustCreateAPIKeyRepoUser
(
t
,
ctx
,
client
,
"deleted-last-used@test.com"
)
key
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-update-last-used-deleted"
,
Name
:
"UpdateLastUsedDeleted"
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
key
))
require
.
NoError
(
t
,
repo
.
Delete
(
ctx
,
key
.
ID
))
err
:=
repo
.
UpdateLastUsed
(
ctx
,
key
.
ID
,
time
.
Now
()
.
UTC
())
require
.
ErrorIs
(
t
,
err
,
service
.
ErrAPIKeyNotFound
)
}
func
TestAPIKeyRepository_UpdateLastUsedDBError
(
t
*
testing
.
T
)
{
repo
,
client
:=
newAPIKeyRepoSQLite
(
t
)
ctx
:=
context
.
Background
()
user
:=
mustCreateAPIKeyRepoUser
(
t
,
ctx
,
client
,
"db-error-last-used@test.com"
)
key
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-update-last-used-db-error"
,
Name
:
"UpdateLastUsedDBError"
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
key
))
require
.
NoError
(
t
,
client
.
Close
())
err
:=
repo
.
UpdateLastUsed
(
ctx
,
key
.
ID
,
time
.
Now
()
.
UTC
())
require
.
Error
(
t
,
err
)
}
func
TestAPIKeyRepository_CreateDuplicateKey
(
t
*
testing
.
T
)
{
repo
,
client
:=
newAPIKeyRepoSQLite
(
t
)
ctx
:=
context
.
Background
()
user
:=
mustCreateAPIKeyRepoUser
(
t
,
ctx
,
client
,
"duplicate-key@test.com"
)
first
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-duplicate"
,
Name
:
"first"
,
Status
:
service
.
StatusActive
,
}
second
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-duplicate"
,
Name
:
"second"
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
first
))
err
:=
repo
.
Create
(
ctx
,
second
)
require
.
ErrorIs
(
t
,
err
,
service
.
ErrAPIKeyExists
)
}
backend/internal/repository/billing_cache.go
View file @
3d79773b
...
...
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
...
...
@@ -15,9 +16,22 @@ import (
const
(
billingBalanceKeyPrefix
=
"billing:balance:"
billingSubKeyPrefix
=
"billing:sub:"
billingRateLimitKeyPrefix
=
"apikey:rate:"
billingCacheTTL
=
5
*
time
.
Minute
billingCacheJitter
=
30
*
time
.
Second
rateLimitCacheTTL
=
7
*
24
*
time
.
Hour
// 7 days matches the longest window
)
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
func
jitteredTTL
()
time
.
Duration
{
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。
if
billingCacheJitter
<=
0
{
return
billingCacheTTL
}
jitter
:=
time
.
Duration
(
rand
.
IntN
(
int
(
billingCacheJitter
)))
return
billingCacheTTL
-
jitter
}
// billingBalanceKey generates the Redis key for user balance cache.
func
billingBalanceKey
(
userID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
...
...
@@ -37,6 +51,20 @@ const (
subFieldVersion
=
"version"
)
// billingRateLimitKey generates the Redis key for API key rate limit cache.
func
billingRateLimitKey
(
keyID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
billingRateLimitKeyPrefix
,
keyID
)
}
const
(
rateLimitFieldUsage5h
=
"usage_5h"
rateLimitFieldUsage1d
=
"usage_1d"
rateLimitFieldUsage7d
=
"usage_7d"
rateLimitFieldWindow5h
=
"window_5h"
rateLimitFieldWindow1d
=
"window_1d"
rateLimitFieldWindow7d
=
"window_7d"
)
var
(
deductBalanceScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
...
...
@@ -61,6 +89,21 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`
)
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
updateRateLimitUsageScript
=
redis
.
NewScript
(
`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`
)
)
type
billingCache
struct
{
...
...
@@ -82,14 +125,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
func
(
c
*
billingCache
)
SetUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
)
error
{
key
:=
billingBalanceKey
(
userID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
balance
,
billingCache
TTL
)
.
Err
()
return
c
.
rdb
.
Set
(
ctx
,
key
,
balance
,
jittered
TTL
()
)
.
Err
()
}
func
(
c
*
billingCache
)
DeductUserBalance
(
ctx
context
.
Context
,
userID
int64
,
amount
float64
)
error
{
key
:=
billingBalanceKey
(
userID
)
_
,
err
:=
deductBalanceScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
amount
,
int
(
billingCache
TTL
.
Seconds
()))
.
Result
()
_
,
err
:=
deductBalanceScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
amount
,
int
(
jittered
TTL
()
.
Seconds
()))
.
Result
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
log
.
Printf
(
"Warning: deduct balance cache failed for user %d: %v"
,
userID
,
err
)
return
err
}
return
nil
}
...
...
@@ -163,16 +207,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
pipe
:=
c
.
rdb
.
Pipeline
()
pipe
.
HSet
(
ctx
,
key
,
fields
)
pipe
.
Expire
(
ctx
,
key
,
billingCache
TTL
)
pipe
.
Expire
(
ctx
,
key
,
jittered
TTL
()
)
_
,
err
:=
pipe
.
Exec
(
ctx
)
return
err
}
func
(
c
*
billingCache
)
UpdateSubscriptionUsage
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
cost
float64
)
error
{
key
:=
billingSubKey
(
userID
,
groupID
)
_
,
err
:=
updateSubUsageScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
cost
,
int
(
billingCache
TTL
.
Seconds
()))
.
Result
()
_
,
err
:=
updateSubUsageScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
cost
,
int
(
jittered
TTL
()
.
Seconds
()))
.
Result
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
log
.
Printf
(
"Warning: update subscription usage cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
return
err
}
return
nil
}
...
...
@@ -181,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID,
key
:=
billingSubKey
(
userID
,
groupID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
func
(
c
*
billingCache
)
GetAPIKeyRateLimit
(
ctx
context
.
Context
,
keyID
int64
)
(
*
service
.
APIKeyRateLimitCacheData
,
error
)
{
key
:=
billingRateLimitKey
(
keyID
)
result
,
err
:=
c
.
rdb
.
HGetAll
(
ctx
,
key
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
result
)
==
0
{
return
nil
,
redis
.
Nil
}
data
:=
&
service
.
APIKeyRateLimitCacheData
{}
if
v
,
ok
:=
result
[
rateLimitFieldUsage5h
];
ok
{
data
.
Usage5h
,
_
=
strconv
.
ParseFloat
(
v
,
64
)
}
if
v
,
ok
:=
result
[
rateLimitFieldUsage1d
];
ok
{
data
.
Usage1d
,
_
=
strconv
.
ParseFloat
(
v
,
64
)
}
if
v
,
ok
:=
result
[
rateLimitFieldUsage7d
];
ok
{
data
.
Usage7d
,
_
=
strconv
.
ParseFloat
(
v
,
64
)
}
if
v
,
ok
:=
result
[
rateLimitFieldWindow5h
];
ok
{
data
.
Window5h
,
_
=
strconv
.
ParseInt
(
v
,
10
,
64
)
}
if
v
,
ok
:=
result
[
rateLimitFieldWindow1d
];
ok
{
data
.
Window1d
,
_
=
strconv
.
ParseInt
(
v
,
10
,
64
)
}
if
v
,
ok
:=
result
[
rateLimitFieldWindow7d
];
ok
{
data
.
Window7d
,
_
=
strconv
.
ParseInt
(
v
,
10
,
64
)
}
return
data
,
nil
}
func
(
c
*
billingCache
)
SetAPIKeyRateLimit
(
ctx
context
.
Context
,
keyID
int64
,
data
*
service
.
APIKeyRateLimitCacheData
)
error
{
if
data
==
nil
{
return
nil
}
key
:=
billingRateLimitKey
(
keyID
)
fields
:=
map
[
string
]
any
{
rateLimitFieldUsage5h
:
data
.
Usage5h
,
rateLimitFieldUsage1d
:
data
.
Usage1d
,
rateLimitFieldUsage7d
:
data
.
Usage7d
,
rateLimitFieldWindow5h
:
data
.
Window5h
,
rateLimitFieldWindow1d
:
data
.
Window1d
,
rateLimitFieldWindow7d
:
data
.
Window7d
,
}
pipe
:=
c
.
rdb
.
Pipeline
()
pipe
.
HSet
(
ctx
,
key
,
fields
)
pipe
.
Expire
(
ctx
,
key
,
rateLimitCacheTTL
)
_
,
err
:=
pipe
.
Exec
(
ctx
)
return
err
}
func
(
c
*
billingCache
)
UpdateAPIKeyRateLimitUsage
(
ctx
context
.
Context
,
keyID
int64
,
cost
float64
)
error
{
key
:=
billingRateLimitKey
(
keyID
)
_
,
err
:=
updateRateLimitUsageScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
cost
,
int
(
rateLimitCacheTTL
.
Seconds
()))
.
Result
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
log
.
Printf
(
"Warning: update rate limit usage cache failed for api key %d: %v"
,
keyID
,
err
)
return
err
}
return
nil
}
func
(
c
*
billingCache
)
InvalidateAPIKeyRateLimit
(
ctx
context
.
Context
,
keyID
int64
)
error
{
key
:=
billingRateLimitKey
(
keyID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
Prev
1
…
7
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment