Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0170d19f
Commit
0170d19f
authored
Feb 02, 2026
by
song
Browse files
merge upstream main
parent
7ade9baa
Changes
319
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/gateway_routing_integration_test.go
View file @
0170d19f
...
...
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
s
.
client
=
tx
.
Client
()
s
.
accountRepo
=
newAccountRepositoryWithSQL
(
s
.
client
,
tx
)
s
.
accountRepo
=
newAccountRepositoryWithSQL
(
s
.
client
,
tx
,
nil
)
}
func
TestGatewayRoutingSuite
(
t
*
testing
.
T
)
{
...
...
backend/internal/repository/http_upstream.go
View file @
0170d19f
...
...
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
...
...
@@ -14,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
...
...
@@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return
resp
,
nil
}
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
//
// 参数:
// - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
func
(
s
*
httpUpstreamService
)
DoWithTLS
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
enableTLSFingerprint
bool
)
(
*
http
.
Response
,
error
)
{
// 如果未启用 TLS 指纹,直接使用标准请求路径
if
!
enableTLSFingerprint
{
return
s
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
// TLS 指纹已启用,记录调试日志
targetHost
:=
""
if
req
!=
nil
&&
req
.
URL
!=
nil
{
targetHost
=
req
.
URL
.
Host
}
proxyInfo
:=
"direct"
if
proxyURL
!=
""
{
proxyInfo
=
proxyURL
}
slog
.
Debug
(
"tls_fingerprint_enabled"
,
"account_id"
,
accountID
,
"target"
,
targetHost
,
"proxy"
,
proxyInfo
)
if
err
:=
s
.
validateRequestHost
(
req
);
err
!=
nil
{
return
nil
,
err
}
// 获取 TLS 指纹 Profile
registry
:=
tlsfingerprint
.
GlobalRegistry
()
profile
:=
registry
.
GetProfileByAccountID
(
accountID
)
if
profile
==
nil
{
// 如果获取不到 profile,回退到普通请求
slog
.
Debug
(
"tls_fingerprint_no_profile"
,
"account_id"
,
accountID
,
"fallback"
,
"standard_request"
)
return
s
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
slog
.
Debug
(
"tls_fingerprint_using_profile"
,
"account_id"
,
accountID
,
"profile"
,
profile
.
Name
,
"grease"
,
profile
.
EnableGREASE
)
// 获取或创建带 TLS 指纹的客户端
entry
,
err
:=
s
.
acquireClientWithTLS
(
proxyURL
,
accountID
,
accountConcurrency
,
profile
)
if
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_acquire_client_failed"
,
"account_id"
,
accountID
,
"error"
,
err
)
return
nil
,
err
}
// 执行请求
resp
,
err
:=
entry
.
client
.
Do
(
req
)
if
err
!=
nil
{
// 请求失败,立即减少计数
atomic
.
AddInt64
(
&
entry
.
inFlight
,
-
1
)
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
time
.
Now
()
.
UnixNano
())
slog
.
Debug
(
"tls_fingerprint_request_failed"
,
"account_id"
,
accountID
,
"error"
,
err
)
return
nil
,
err
}
slog
.
Debug
(
"tls_fingerprint_request_success"
,
"account_id"
,
accountID
,
"status"
,
resp
.
StatusCode
)
// 包装响应体,在关闭时自动减少计数并更新时间戳
resp
.
Body
=
wrapTrackedBody
(
resp
.
Body
,
func
()
{
atomic
.
AddInt64
(
&
entry
.
inFlight
,
-
1
)
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
time
.
Now
()
.
UnixNano
())
})
return
resp
,
nil
}
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
func
(
s
*
httpUpstreamService
)
acquireClientWithTLS
(
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
profile
*
tlsfingerprint
.
Profile
)
(
*
upstreamClientEntry
,
error
)
{
return
s
.
getClientEntryWithTLS
(
proxyURL
,
accountID
,
accountConcurrency
,
profile
,
true
,
true
)
}
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
func
(
s
*
httpUpstreamService
)
getClientEntryWithTLS
(
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
profile
*
tlsfingerprint
.
Profile
,
markInFlight
bool
,
enforceLimit
bool
)
(
*
upstreamClientEntry
,
error
)
{
isolation
:=
s
.
getIsolationMode
()
proxyKey
,
parsedProxy
:=
normalizeProxyURL
(
proxyURL
)
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
cacheKey
:=
"tls:"
+
buildCacheKey
(
isolation
,
proxyKey
,
accountID
)
poolKey
:=
s
.
buildPoolKey
(
isolation
,
accountConcurrency
)
+
":tls"
now
:=
time
.
Now
()
nowUnix
:=
now
.
UnixNano
()
// 读锁快速路径
s
.
mu
.
RLock
()
if
entry
,
ok
:=
s
.
clients
[
cacheKey
];
ok
&&
s
.
shouldReuseEntry
(
entry
,
isolation
,
proxyKey
,
poolKey
)
{
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
nowUnix
)
if
markInFlight
{
atomic
.
AddInt64
(
&
entry
.
inFlight
,
1
)
}
s
.
mu
.
RUnlock
()
slog
.
Debug
(
"tls_fingerprint_reusing_client"
,
"account_id"
,
accountID
,
"cache_key"
,
cacheKey
)
return
entry
,
nil
}
s
.
mu
.
RUnlock
()
// 写锁慢路径
s
.
mu
.
Lock
()
if
entry
,
ok
:=
s
.
clients
[
cacheKey
];
ok
{
if
s
.
shouldReuseEntry
(
entry
,
isolation
,
proxyKey
,
poolKey
)
{
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
nowUnix
)
if
markInFlight
{
atomic
.
AddInt64
(
&
entry
.
inFlight
,
1
)
}
s
.
mu
.
Unlock
()
slog
.
Debug
(
"tls_fingerprint_reusing_client"
,
"account_id"
,
accountID
,
"cache_key"
,
cacheKey
)
return
entry
,
nil
}
slog
.
Debug
(
"tls_fingerprint_evicting_stale_client"
,
"account_id"
,
accountID
,
"cache_key"
,
cacheKey
,
"proxy_changed"
,
entry
.
proxyKey
!=
proxyKey
,
"pool_changed"
,
entry
.
poolKey
!=
poolKey
)
s
.
removeClientLocked
(
cacheKey
,
entry
)
}
// 超出缓存上限时尝试淘汰
if
enforceLimit
&&
s
.
maxUpstreamClients
()
>
0
{
s
.
evictIdleLocked
(
now
)
if
len
(
s
.
clients
)
>=
s
.
maxUpstreamClients
()
{
if
!
s
.
evictOldestIdleLocked
()
{
s
.
mu
.
Unlock
()
return
nil
,
errUpstreamClientLimitReached
}
}
}
// 创建带 TLS 指纹的 Transport
slog
.
Debug
(
"tls_fingerprint_creating_new_client"
,
"account_id"
,
accountID
,
"cache_key"
,
cacheKey
,
"proxy"
,
proxyKey
)
settings
:=
s
.
resolvePoolSettings
(
isolation
,
accountConcurrency
)
transport
,
err
:=
buildUpstreamTransportWithTLSFingerprint
(
settings
,
parsedProxy
,
profile
)
if
err
!=
nil
{
s
.
mu
.
Unlock
()
return
nil
,
fmt
.
Errorf
(
"build TLS fingerprint transport: %w"
,
err
)
}
client
:=
&
http
.
Client
{
Transport
:
transport
}
if
s
.
shouldValidateResolvedIP
()
{
client
.
CheckRedirect
=
s
.
redirectChecker
}
entry
:=
&
upstreamClientEntry
{
client
:
client
,
proxyKey
:
proxyKey
,
poolKey
:
poolKey
,
}
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
nowUnix
)
if
markInFlight
{
atomic
.
StoreInt64
(
&
entry
.
inFlight
,
1
)
}
s
.
clients
[
cacheKey
]
=
entry
s
.
evictIdleLocked
(
now
)
s
.
evictOverLimitLocked
()
s
.
mu
.
Unlock
()
return
entry
,
nil
}
func
(
s
*
httpUpstreamService
)
shouldValidateResolvedIP
()
bool
{
if
s
.
cfg
==
nil
{
return
false
...
...
@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
return
transport
,
nil
}
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URL(nil 表示直连)
// - profile: TLS 指纹配置
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
// - error: 配置错误
//
// 代理类型处理:
// - nil/空: 直连,使用 TLSFingerprintDialer
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
func
buildUpstreamTransportWithTLSFingerprint
(
settings
poolSettings
,
proxyURL
*
url
.
URL
,
profile
*
tlsfingerprint
.
Profile
)
(
*
http
.
Transport
,
error
)
{
transport
:=
&
http
.
Transport
{
MaxIdleConns
:
settings
.
maxIdleConns
,
MaxIdleConnsPerHost
:
settings
.
maxIdleConnsPerHost
,
MaxConnsPerHost
:
settings
.
maxConnsPerHost
,
IdleConnTimeout
:
settings
.
idleConnTimeout
,
ResponseHeaderTimeout
:
settings
.
responseHeaderTimeout
,
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
ForceAttemptHTTP2
:
false
,
}
// 根据代理类型选择合适的 TLS 指纹 Dialer
if
proxyURL
==
nil
{
// 直连:使用 TLSFingerprintDialer
slog
.
Debug
(
"tls_fingerprint_transport_direct"
)
dialer
:=
tlsfingerprint
.
NewDialer
(
profile
,
nil
)
transport
.
DialTLSContext
=
dialer
.
DialTLSContext
}
else
{
scheme
:=
strings
.
ToLower
(
proxyURL
.
Scheme
)
switch
scheme
{
case
"socks5"
,
"socks5h"
:
// SOCKS5 代理:使用 SOCKS5ProxyDialer
slog
.
Debug
(
"tls_fingerprint_transport_socks5"
,
"proxy"
,
proxyURL
.
Host
)
socks5Dialer
:=
tlsfingerprint
.
NewSOCKS5ProxyDialer
(
profile
,
proxyURL
)
transport
.
DialTLSContext
=
socks5Dialer
.
DialTLSContext
case
"http"
,
"https"
:
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
slog
.
Debug
(
"tls_fingerprint_transport_http_connect"
,
"proxy"
,
proxyURL
.
Host
)
httpDialer
:=
tlsfingerprint
.
NewHTTPProxyDialer
(
profile
,
proxyURL
)
transport
.
DialTLSContext
=
httpDialer
.
DialTLSContext
default
:
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
slog
.
Debug
(
"tls_fingerprint_transport_unknown_scheme_fallback"
,
"scheme"
,
scheme
)
if
err
:=
proxyutil
.
ConfigureTransportProxy
(
transport
,
proxyURL
);
err
!=
nil
{
return
nil
,
err
}
}
}
return
transport
,
nil
}
// trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type
trackedBody
struct
{
...
...
backend/internal/repository/identity_cache.go
View file @
0170d19f
...
...
@@ -13,6 +13,8 @@ import (
const
(
fingerprintKeyPrefix
=
"fingerprint:"
fingerprintTTL
=
24
*
time
.
Hour
maskedSessionKeyPrefix
=
"masked_session:"
maskedSessionTTL
=
15
*
time
.
Minute
)
// fingerprintKey generates the Redis key for account fingerprint cache.
...
...
@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
return
fmt
.
Sprintf
(
"%s%d"
,
fingerprintKeyPrefix
,
accountID
)
}
// maskedSessionKey generates the Redis key for masked session ID cache.
func
maskedSessionKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
maskedSessionKeyPrefix
,
accountID
)
}
type
identityCache
struct
{
rdb
*
redis
.
Client
}
...
...
@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
}
return
c
.
rdb
.
Set
(
ctx
,
key
,
val
,
fingerprintTTL
)
.
Err
()
}
func
(
c
*
identityCache
)
GetMaskedSessionID
(
ctx
context
.
Context
,
accountID
int64
)
(
string
,
error
)
{
key
:=
maskedSessionKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
""
,
nil
}
return
""
,
err
}
return
val
,
nil
}
func
(
c
*
identityCache
)
SetMaskedSessionID
(
ctx
context
.
Context
,
accountID
int64
,
sessionID
string
)
error
{
key
:=
maskedSessionKey
(
accountID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
sessionID
,
maskedSessionTTL
)
.
Err
()
}
backend/internal/repository/openai_oauth_service.go
View file @
0170d19f
...
...
@@ -2,10 +2,11 @@ package repository
import
(
"context"
"
fmt
"
"
net/http
"
"net/url"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
...
...
@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"User-Agent"
,
"codex-cli/0.91.0"
)
.
SetFormDataFromValues
(
formData
)
.
SetSuccessResult
(
&
tokenResp
)
.
Post
(
s
.
tokenURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request failed: %
w
"
,
err
)
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"OPENAI_OAUTH_REQUEST_FAILED"
,
"request failed: %
v
"
,
err
)
}
if
!
resp
.
IsSuccessState
()
{
return
nil
,
fmt
.
Errorf
(
"token exchange failed: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED"
,
"token exchange failed: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
}
return
&
tokenResp
,
nil
...
...
@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"User-Agent"
,
"codex-cli/0.91.0"
)
.
SetFormDataFromValues
(
formData
)
.
SetSuccessResult
(
&
tokenResp
)
.
Post
(
s
.
tokenURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request failed: %
w
"
,
err
)
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"OPENAI_OAUTH_REQUEST_FAILED"
,
"request failed: %
v
"
,
err
)
}
if
!
resp
.
IsSuccessState
()
{
return
nil
,
fmt
.
Errorf
(
"token refresh failed: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"OPENAI_OAUTH_TOKEN_REFRESH_FAILED"
,
"token refresh failed: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
}
return
&
tokenResp
,
nil
...
...
@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
func
createOpenAIReqClient
(
proxyURL
string
)
*
req
.
Client
{
return
getSharedReqClient
(
reqClientOptions
{
ProxyURL
:
proxyURL
,
Timeout
:
6
0
*
time
.
Second
,
Timeout
:
12
0
*
time
.
Second
,
})
}
backend/internal/repository/openai_oauth_service_test.go
View file @
0170d19f
...
...
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require
.
ErrorContains
(
s
.
T
(),
err
,
"status 401"
)
}
func
TestNewOpenAIOAuthClient_DefaultTokenURL
(
t
*
testing
.
T
)
{
client
:=
NewOpenAIOAuthClient
()
svc
,
ok
:=
client
.
(
*
openaiOAuthService
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
TokenURL
,
svc
.
tokenURL
)
}
func
TestOpenAIOAuthServiceSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
OpenAIOAuthServiceSuite
))
}
backend/internal/repository/ops_repo.go
View file @
0170d19f
...
...
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
// Excluded = business-limited errors (quota/concurrency/billing).
// Upstream 429/529 are included in errors view to match SLA calculation.
view
:=
""
if
filter
!=
nil
{
view
=
strings
.
ToLower
(
strings
.
TrimSpace
(
filter
.
View
))
...
...
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
switch
view
{
case
""
,
"errors"
:
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
case
"excluded"
:
clauses
=
append
(
clauses
,
"
(
COALESCE(is_business_limited,false) = true
OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))
"
)
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = true"
)
case
"all"
:
// no-op
default
:
// treat unknown as default 'errors'
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
}
if
len
(
filter
.
StatusCodes
)
>
0
{
args
=
append
(
args
,
pq
.
Array
(
filter
.
StatusCodes
))
...
...
backend/internal/repository/redis.go
View file @
0170d19f
package
repository
import
(
"crypto/tls"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func
buildRedisOptions
(
cfg
*
config
.
Config
)
*
redis
.
Options
{
return
&
redis
.
Options
{
opts
:=
&
redis
.
Options
{
Addr
:
cfg
.
Redis
.
Address
(),
Password
:
cfg
.
Redis
.
Password
,
DB
:
cfg
.
Redis
.
DB
,
...
...
@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
PoolSize
:
cfg
.
Redis
.
PoolSize
,
// 连接池大小
MinIdleConns
:
cfg
.
Redis
.
MinIdleConns
,
// 最小空闲连接
}
if
cfg
.
Redis
.
EnableTLS
{
opts
.
TLSConfig
=
&
tls
.
Config
{
MinVersion
:
tls
.
VersionTLS12
,
ServerName
:
cfg
.
Redis
.
Host
,
}
}
return
opts
}
backend/internal/repository/redis_test.go
View file @
0170d19f
...
...
@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
require
.
Equal
(
t
,
4
*
time
.
Second
,
opts
.
WriteTimeout
)
require
.
Equal
(
t
,
100
,
opts
.
PoolSize
)
require
.
Equal
(
t
,
10
,
opts
.
MinIdleConns
)
require
.
Nil
(
t
,
opts
.
TLSConfig
)
// Test case with TLS enabled
cfgTLS
:=
&
config
.
Config
{
Redis
:
config
.
RedisConfig
{
Host
:
"localhost"
,
EnableTLS
:
true
,
},
}
optsTLS
:=
buildRedisOptions
(
cfgTLS
)
require
.
NotNil
(
t
,
optsTLS
.
TLSConfig
)
require
.
Equal
(
t
,
"localhost"
,
optsTLS
.
TLSConfig
.
ServerName
)
}
backend/internal/repository/req_client_pool.go
View file @
0170d19f
...
...
@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL
string
// 代理 URL(支持 http/https/socks5)
Timeout
time
.
Duration
// 请求超时时间
Impersonate
bool
// 是否模拟 Chrome 浏览器指纹
ForceHTTP2
bool
// 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
...
...
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client
:=
req
.
C
()
.
SetTimeout
(
opts
.
Timeout
)
if
opts
.
ForceHTTP2
{
client
=
client
.
EnableForceHTTP2
()
}
if
opts
.
Impersonate
{
client
=
client
.
ImpersonateChrome
()
}
...
...
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func
buildReqClientKey
(
opts
reqClientOptions
)
string
{
return
fmt
.
Sprintf
(
"%s|%s|%t"
,
return
fmt
.
Sprintf
(
"%s|%s|%t
|%t
"
,
strings
.
TrimSpace
(
opts
.
ProxyURL
),
opts
.
Timeout
.
String
(),
opts
.
Impersonate
,
opts
.
ForceHTTP2
,
)
}
backend/internal/repository/req_client_pool_test.go
0 → 100644
View file @
0170d19f
package
repository
import
(
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func
forceHTTPVersion
(
t
*
testing
.
T
,
client
*
req
.
Client
)
string
{
t
.
Helper
()
transport
:=
client
.
GetTransport
()
field
:=
reflect
.
ValueOf
(
transport
)
.
Elem
()
.
FieldByName
(
"forceHttpVersion"
)
require
.
True
(
t
,
field
.
IsValid
(),
"forceHttpVersion field not found"
)
require
.
True
(
t
,
field
.
CanAddr
(),
"forceHttpVersion field not addressable"
)
return
reflect
.
NewAt
(
field
.
Type
(),
unsafe
.
Pointer
(
field
.
UnsafeAddr
()))
.
Elem
()
.
String
()
}
func
TestGetSharedReqClient_ForceHTTP2SeparatesCache
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
base
:=
reqClientOptions
{
ProxyURL
:
"http://proxy.local:8080"
,
Timeout
:
time
.
Second
,
}
clientDefault
:=
getSharedReqClient
(
base
)
force
:=
base
force
.
ForceHTTP2
=
true
clientForce
:=
getSharedReqClient
(
force
)
require
.
NotSame
(
t
,
clientDefault
,
clientForce
)
require
.
NotEqual
(
t
,
buildReqClientKey
(
base
),
buildReqClientKey
(
force
))
}
func
TestGetSharedReqClient_ReuseCachedClient
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
"http://proxy.local:8080"
,
Timeout
:
2
*
time
.
Second
,
}
first
:=
getSharedReqClient
(
opts
)
second
:=
getSharedReqClient
(
opts
)
require
.
Same
(
t
,
first
,
second
)
}
func
TestGetSharedReqClient_IgnoresNonClientCache
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
" http://proxy.local:8080 "
,
Timeout
:
3
*
time
.
Second
,
}
key
:=
buildReqClientKey
(
opts
)
sharedReqClients
.
Store
(
key
,
"invalid"
)
client
:=
getSharedReqClient
(
opts
)
require
.
NotNil
(
t
,
client
)
loaded
,
ok
:=
sharedReqClients
.
Load
(
key
)
require
.
True
(
t
,
ok
)
require
.
IsType
(
t
,
"invalid"
,
loaded
)
}
func
TestGetSharedReqClient_ImpersonateAndProxy
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
" http://proxy.local:8080 "
,
Timeout
:
4
*
time
.
Second
,
Impersonate
:
true
,
}
client
:=
getSharedReqClient
(
opts
)
require
.
NotNil
(
t
,
client
)
require
.
Equal
(
t
,
"http://proxy.local:8080|4s|true|false"
,
buildReqClientKey
(
opts
))
}
func
TestCreateOpenAIReqClient_Timeout120Seconds
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createOpenAIReqClient
(
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
120
*
time
.
Second
,
client
.
GetClient
()
.
Timeout
)
}
func
TestCreateGeminiReqClient_ForceHTTP2Disabled
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createGeminiReqClient
(
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
""
,
forceHTTPVersion
(
t
,
client
))
}
backend/internal/repository/scheduler_cache.go
View file @
0170d19f
...
...
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
return
nil
,
false
,
err
}
if
len
(
ids
)
==
0
{
return
[]
*
service
.
Account
{},
true
,
nil
// 空快照视为缓存未命中,触发数据库回退查询
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
return
nil
,
false
,
nil
}
keys
:=
make
([]
string
,
0
,
len
(
ids
))
...
...
backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
View file @
0170d19f
...
...
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_
,
_
=
integrationDB
.
ExecContext
(
ctx
,
"TRUNCATE scheduler_outbox"
)
accountRepo
:=
newAccountRepositoryWithSQL
(
client
,
integrationDB
)
accountRepo
:=
newAccountRepositoryWithSQL
(
client
,
integrationDB
,
nil
)
outboxRepo
:=
NewSchedulerOutboxRepository
(
integrationDB
)
cache
:=
NewSchedulerCache
(
rdb
)
...
...
backend/internal/repository/session_limit_cache.go
View file @
0170d19f
...
...
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
idleTimeouts
map
[
int64
]
time
.
Duration
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
int
),
nil
}
...
...
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
// 使用 pipeline 批量执行
pipe
:=
c
.
rdb
.
Pipeline
()
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
cmds
:=
make
(
map
[
int64
]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
key
:=
sessionLimitKey
(
accountID
)
// 使用各账号自己的 idleTimeout,如果没有则用默认值
idleTimeout
:=
c
.
defaultIdleTimeout
if
idleTimeouts
!=
nil
{
if
t
,
ok
:=
idleTimeouts
[
accountID
];
ok
&&
t
>
0
{
idleTimeout
=
t
}
}
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
cmds
[
accountID
]
=
getActiveSessionCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
idleTimeoutSeconds
)
}
...
...
backend/internal/repository/simple_mode_default_groups.go
0 → 100644
View file @
0170d19f
package
repository
import
(
"context"
"fmt"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
ensureSimpleModeDefaultGroups
(
ctx
context
.
Context
,
client
*
dbent
.
Client
)
error
{
if
client
==
nil
{
return
fmt
.
Errorf
(
"nil ent client"
)
}
requiredByPlatform
:=
map
[
string
]
int
{
service
.
PlatformAnthropic
:
1
,
service
.
PlatformOpenAI
:
1
,
service
.
PlatformGemini
:
1
,
service
.
PlatformAntigravity
:
2
,
}
for
platform
,
minCount
:=
range
requiredByPlatform
{
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
platform
),
group
.
DeletedAtIsNil
())
.
Count
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"count groups for platform %s: %w"
,
platform
,
err
)
}
if
platform
==
service
.
PlatformAntigravity
{
if
count
<
minCount
{
for
i
:=
count
;
i
<
minCount
;
i
++
{
name
:=
fmt
.
Sprintf
(
"%s-default-%d"
,
platform
,
i
+
1
)
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
}
continue
}
// Non-antigravity platforms: ensure <platform>-default exists.
name
:=
platform
+
"-default"
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
return
nil
}
func
createGroupIfNotExists
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
name
,
platform
string
)
error
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check group exists %s: %w"
,
name
,
err
)
}
if
exists
{
return
nil
}
_
,
err
=
client
.
Group
.
Create
()
.
SetName
(
name
)
.
SetDescription
(
"Auto-created default group"
)
.
SetPlatform
(
platform
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsConstraintError
(
err
)
{
// Concurrent server startups may race on creation; treat as success.
return
nil
}
return
fmt
.
Errorf
(
"create default group %s: %w"
,
name
,
err
)
}
return
nil
}
backend/internal/repository/simple_mode_default_groups_integration_test.go
0 → 100644
View file @
0170d19f
//go:build integration
package
repository
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
assertGroupExists
:=
func
(
name
string
)
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
exists
,
"expected group %s to exist"
,
name
)
}
assertGroupExists
(
service
.
PlatformAnthropic
+
"-default"
)
assertGroupExists
(
service
.
PlatformOpenAI
+
"-default"
)
assertGroupExists
(
service
.
PlatformGemini
+
"-default"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-1"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-2"
)
}
func
TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
// Create and then soft-delete an anthropic default group.
g
,
err
:=
client
.
Group
.
Create
()
.
SetName
(
service
.
PlatformAnthropic
+
"-default"
)
.
SetPlatform
(
service
.
PlatformAnthropic
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
seedCtx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
Group
.
Delete
()
.
Where
(
group
.
IDEQ
(
g
.
ID
))
.
Exec
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
// New active one should exist.
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
service
.
PlatformAnthropic
+
"-default"
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
}
func
TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-1-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-2-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
service
.
PlatformAntigravity
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
GreaterOrEqual
(
t
,
count
,
2
)
}
backend/internal/repository/totp_cache.go
0 → 100644
View file @
0170d19f
package
repository
import
(
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const
(
totpSetupKeyPrefix
=
"totp:setup:"
totpLoginKeyPrefix
=
"totp:login:"
totpAttemptsKeyPrefix
=
"totp:attempts:"
totpAttemptsTTL
=
15
*
time
.
Minute
)
// TotpCache implements service.TotpCache using Redis
type
TotpCache
struct
{
rdb
*
redis
.
Client
}
// NewTotpCache creates a new TOTP cache
func
NewTotpCache
(
rdb
*
redis
.
Client
)
service
.
TotpCache
{
return
&
TotpCache
{
rdb
:
rdb
}
}
// GetSetupSession retrieves a TOTP setup session
func
(
c
*
TotpCache
)
GetSetupSession
(
ctx
context
.
Context
,
userID
int64
)
(
*
service
.
TotpSetupSession
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
data
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Bytes
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
return
nil
,
fmt
.
Errorf
(
"get setup session: %w"
,
err
)
}
var
session
service
.
TotpSetupSession
if
err
:=
json
.
Unmarshal
(
data
,
&
session
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal setup session: %w"
,
err
)
}
return
&
session
,
nil
}
// SetSetupSession stores a TOTP setup session
func
(
c
*
TotpCache
)
SetSetupSession
(
ctx
context
.
Context
,
userID
int64
,
session
*
service
.
TotpSetupSession
,
ttl
time
.
Duration
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
data
,
err
:=
json
.
Marshal
(
session
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal setup session: %w"
,
err
)
}
if
err
:=
c
.
rdb
.
Set
(
ctx
,
key
,
data
,
ttl
)
.
Err
();
err
!=
nil
{
return
fmt
.
Errorf
(
"set setup session: %w"
,
err
)
}
return
nil
}
// DeleteSetupSession deletes a TOTP setup session
func
(
c
*
TotpCache
)
DeleteSetupSession
(
ctx
context
.
Context
,
userID
int64
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// GetLoginSession retrieves a TOTP login session
func
(
c
*
TotpCache
)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
service
.
TotpLoginSession
,
error
)
{
key
:=
totpLoginKeyPrefix
+
tempToken
data
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Bytes
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
return
nil
,
fmt
.
Errorf
(
"get login session: %w"
,
err
)
}
var
session
service
.
TotpLoginSession
if
err
:=
json
.
Unmarshal
(
data
,
&
session
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal login session: %w"
,
err
)
}
return
&
session
,
nil
}
// SetLoginSession stores a TOTP login session
func
(
c
*
TotpCache
)
SetLoginSession
(
ctx
context
.
Context
,
tempToken
string
,
session
*
service
.
TotpLoginSession
,
ttl
time
.
Duration
)
error
{
key
:=
totpLoginKeyPrefix
+
tempToken
data
,
err
:=
json
.
Marshal
(
session
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal login session: %w"
,
err
)
}
if
err
:=
c
.
rdb
.
Set
(
ctx
,
key
,
data
,
ttl
)
.
Err
();
err
!=
nil
{
return
fmt
.
Errorf
(
"set login session: %w"
,
err
)
}
return
nil
}
// DeleteLoginSession deletes a TOTP login session
func
(
c
*
TotpCache
)
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
{
key
:=
totpLoginKeyPrefix
+
tempToken
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// IncrementVerifyAttempts increments the verify attempt counter for the user
// and (re)arms its TTL, returning the new attempt count. Used to rate-limit
// TOTP verification attempts.
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
	key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
	// Use pipeline for atomic increment and set TTL
	// (one round trip; EXPIRE refreshes the window on every attempt).
	pipe := c.rdb.Pipeline()
	incrCmd := pipe.Incr(ctx, key)
	pipe.Expire(ctx, key, totpAttemptsTTL)
	if _, err := pipe.Exec(ctx); err != nil {
		return 0, fmt.Errorf("increment verify attempts: %w", err)
	}
	// Read the INCR result only after Exec — pipelined command results are
	// populated by Exec.
	count, err := incrCmd.Result()
	if err != nil {
		return 0, fmt.Errorf("get increment result: %w", err)
	}
	return int(count), nil
}
// GetVerifyAttempts gets the current verify attempt count for the user.
// A missing key (no attempts recorded, or the TTL expired) is reported as 0
// with a nil error.
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
	key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
	count, err := c.rdb.Get(ctx, key).Int()
	if err != nil {
		// redis.Nil means the counter does not exist yet.
		if err == redis.Nil {
			return 0, nil
		}
		return 0, fmt.Errorf("get verify attempts: %w", err)
	}
	return count, nil
}
// ClearVerifyAttempts clears the verify attempt counter
func
(
c
*
TotpCache
)
ClearVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpAttemptsKeyPrefix
,
userID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
backend/internal/repository/usage_cleanup_repo.go
0 → 100644
View file @
0170d19f
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbusagecleanuptask
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// usageCleanupRepository persists usage-cleanup tasks. It has two backends:
// when client is non-nil the ent client is used, otherwise operations fall
// back to raw SQL through the sqlExecutor.
type usageCleanupRepository struct {
	client *dbent.Client // ent backend; preferred when non-nil
	sql    sqlExecutor   // raw-SQL backend (e.g. *sql.DB or a transaction)
}
func
NewUsageCleanupRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
)
service
.
UsageCleanupRepository
{
return
newUsageCleanupRepositoryWithSQL
(
client
,
sqlDB
)
}
func
newUsageCleanupRepositoryWithSQL
(
client
*
dbent
.
Client
,
sqlq
sqlExecutor
)
*
usageCleanupRepository
{
return
&
usageCleanupRepository
{
client
:
client
,
sql
:
sqlq
}
}
func
(
r
*
usageCleanupRepository
)
CreateTask
(
ctx
context
.
Context
,
task
*
service
.
UsageCleanupTask
)
error
{
if
task
==
nil
{
return
nil
}
if
r
.
client
!=
nil
{
return
r
.
createTaskWithEnt
(
ctx
,
task
)
}
return
r
.
createTaskWithSQL
(
ctx
,
task
)
}
// ListTasks returns one page of cleanup tasks, newest first, together with
// pagination metadata. It delegates to the ent backend when available;
// otherwise it counts and pages with raw SQL.
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
	if r.client != nil {
		return r.listTasksWithEnt(ctx, params)
	}
	// Count first so an empty table can short-circuit without a second query.
	var total int64
	if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
		return nil, nil, err
	}
	if total == 0 {
		return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
	}
	// id DESC tiebreaks rows created at the same instant so paging is stable.
	query := `
		SELECT id, status, filters, created_by, deleted_rows, error_message,
		       canceled_by, canceled_at,
		       started_at, finished_at, created_at, updated_at
		FROM usage_cleanup_tasks
		ORDER BY created_at DESC, id DESC
		LIMIT $1 OFFSET $2
	`
	rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
	if err != nil {
		return nil, nil, err
	}
	defer func() { _ = rows.Close() }()
	tasks := make([]service.UsageCleanupTask, 0)
	for rows.Next() {
		var task service.UsageCleanupTask
		var filtersJSON []byte
		// Nullable columns are scanned through sql.Null* and copied into the
		// task's pointer fields only when valid.
		var errMsg sql.NullString
		var canceledBy sql.NullInt64
		var canceledAt sql.NullTime
		var startedAt sql.NullTime
		var finishedAt sql.NullTime
		if err := rows.Scan(
			&task.ID,
			&task.Status,
			&filtersJSON,
			&task.CreatedBy,
			&task.DeletedRows,
			&errMsg,
			&canceledBy,
			&canceledAt,
			&startedAt,
			&finishedAt,
			&task.CreatedAt,
			&task.UpdatedAt,
		); err != nil {
			return nil, nil, err
		}
		// Filters are stored as a JSON document; a corrupt document aborts
		// the whole listing.
		if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
			return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
		}
		if errMsg.Valid {
			task.ErrorMsg = &errMsg.String
		}
		if canceledBy.Valid {
			// Copy into a local so the pointer does not alias the loop's
			// scan target.
			v := canceledBy.Int64
			task.CanceledBy = &v
		}
		if canceledAt.Valid {
			task.CanceledAt = &canceledAt.Time
		}
		if startedAt.Valid {
			task.StartedAt = &startedAt.Time
		}
		if finishedAt.Valid {
			task.FinishedAt = &finishedAt.Time
		}
		tasks = append(tasks, task)
	}
	if err := rows.Err(); err != nil {
		return nil, nil, err
	}
	return tasks, paginationResultFromTotal(total, params), nil
}
// ClaimNextPendingTask atomically claims the oldest runnable cleanup task and
// flips it to running, returning nil when none is available. A task is
// runnable when it is pending, or when it has been running longer than
// staleRunningAfterSeconds (presumed abandoned by a crashed worker).
// FOR UPDATE SKIP LOCKED lets concurrent workers claim different rows without
// blocking each other (PostgreSQL-specific SQL).
func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
	// Default the stale-running window to 30 minutes when unset/invalid.
	if staleRunningAfterSeconds <= 0 {
		staleRunningAfterSeconds = 1800
	}
	query := `
		WITH next AS (
			SELECT id
			FROM usage_cleanup_tasks
			WHERE status = $1
			   OR (
					status = $2
					AND started_at IS NOT NULL
					AND started_at < NOW() - ($3 * interval '1 second')
			   )
			ORDER BY created_at ASC
			LIMIT 1
			FOR UPDATE SKIP LOCKED
		)
		UPDATE usage_cleanup_tasks AS tasks
		SET status = $4,
		    started_at = NOW(),
		    finished_at = NULL,
		    error_message = NULL,
		    updated_at = NOW()
		FROM next
		WHERE tasks.id = next.id
		RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message,
		          tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at
	`
	var task service.UsageCleanupTask
	var filtersJSON []byte
	var errMsg sql.NullString
	var startedAt sql.NullTime
	var finishedAt sql.NullTime
	if err := scanSingleRow(ctx, r.sql, query, []any{
		service.UsageCleanupStatusPending,
		service.UsageCleanupStatusRunning,
		staleRunningAfterSeconds,
		service.UsageCleanupStatusRunning,
	},
		&task.ID,
		&task.Status,
		&filtersJSON,
		&task.CreatedBy,
		&task.DeletedRows,
		&errMsg,
		&startedAt,
		&finishedAt,
		&task.CreatedAt,
		&task.UpdatedAt,
	); err != nil {
		// No claimable task is a normal outcome, not an error.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
		return nil, fmt.Errorf("parse cleanup filters: %w", err)
	}
	if errMsg.Valid {
		task.ErrorMsg = &errMsg.String
	}
	if startedAt.Valid {
		task.StartedAt = &startedAt.Time
	}
	if finishedAt.Valid {
		task.FinishedAt = &finishedAt.Time
	}
	return &task, nil
}
// GetTaskStatus returns the status string for a single task. The SQL fallback
// surfaces a missing row as sql.ErrNoRows (via scanSingleRow); the ent path
// maps not-found to the same sentinel for a uniform contract.
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
	if r.client != nil {
		return r.getTaskStatusWithEnt(ctx, taskID)
	}
	var status string
	if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
		return "", err
	}
	return status, nil
}
// UpdateTaskProgress records the cumulative number of rows deleted so far for
// a running task (absolute value, not a delta) and bumps updated_at.
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
	if r.client != nil {
		return r.updateTaskProgressWithEnt(ctx, taskID, deletedRows)
	}
	query := `
		UPDATE usage_cleanup_tasks
		SET deleted_rows = $1,
		    updated_at = NOW()
		WHERE id = $2
	`
	_, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
	return err
}
// CancelTask marks a pending or running task as canceled, recording who
// canceled it and when. It returns (false, nil) when the task does not exist
// or is already in a terminal state (the status guard makes the cancel
// idempotent and race-safe against a concurrent finish).
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
	if r.client != nil {
		return r.cancelTaskWithEnt(ctx, taskID, canceledBy)
	}
	// RETURNING id distinguishes "updated one row" from "no matching row".
	query := `
		UPDATE usage_cleanup_tasks
		SET status = $1,
		    canceled_by = $3,
		    canceled_at = NOW(),
		    finished_at = NOW(),
		    error_message = NULL,
		    updated_at = NOW()
		WHERE id = $2
		  AND status IN ($4, $5)
		RETURNING id
	`
	var id int64
	err := scanSingleRow(ctx, r.sql, query, []any{
		service.UsageCleanupStatusCanceled,
		taskID,
		canceledBy,
		service.UsageCleanupStatusPending,
		service.UsageCleanupStatusRunning,
	}, &id)
	if errors.Is(err, sql.ErrNoRows) {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	return true, nil
}
// MarkTaskSucceeded transitions a task to the succeeded state, storing the
// final deleted-rows total and stamping finished_at.
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
	if r.client != nil {
		return r.markTaskSucceededWithEnt(ctx, taskID, deletedRows)
	}
	query := `
		UPDATE usage_cleanup_tasks
		SET status = $1,
		    deleted_rows = $2,
		    finished_at = NOW(),
		    updated_at = NOW()
		WHERE id = $3
	`
	_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
	return err
}
// MarkTaskFailed transitions a task to the failed state, storing the partial
// deleted-rows total, the failure message, and the finish timestamp.
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
	if r.client != nil {
		return r.markTaskFailedWithEnt(ctx, taskID, deletedRows, errorMsg)
	}
	query := `
		UPDATE usage_cleanup_tasks
		SET status = $1,
		    deleted_rows = $2,
		    error_message = $3,
		    finished_at = NOW(),
		    updated_at = NOW()
		WHERE id = $4
	`
	_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
	return err
}
// DeleteUsageLogsBatch deletes at most `limit` usage-log rows matching the
// cleanup filters, oldest first, and returns how many were removed. Deleting
// in bounded batches keeps each transaction small; callers repeat until 0 is
// returned. A time range is mandatory to prevent an unbounded wipe.
func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
	if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
		return 0, fmt.Errorf("cleanup filters missing time range")
	}
	whereClause, args := buildUsageCleanupWhere(filters)
	// Defensive re-check: an empty clause would delete everything.
	if whereClause == "" {
		return 0, fmt.Errorf("cleanup filters missing time range")
	}
	// The limit is appended last, so its placeholder number is len(args).
	args = append(args, limit)
	query := fmt.Sprintf(`
		WITH target AS (
			SELECT id
			FROM usage_logs
			WHERE %s
			ORDER BY created_at ASC, id ASC
			LIMIT $%d
		)
		DELETE FROM usage_logs
		WHERE id IN (SELECT id FROM target)
		RETURNING id
	`, whereClause, len(args))
	rows, err := r.sql.QueryContext(ctx, query, args...)
	if err != nil {
		return 0, err
	}
	defer func() { _ = rows.Close() }()
	// Count RETURNING rows to learn how many were actually deleted.
	var deleted int64
	for rows.Next() {
		deleted++
	}
	if err := rows.Err(); err != nil {
		return 0, err
	}
	return deleted, nil
}
func
buildUsageCleanupWhere
(
filters
service
.
UsageCleanupFilters
)
(
string
,
[]
any
)
{
conditions
:=
make
([]
string
,
0
,
8
)
args
:=
make
([]
any
,
0
,
8
)
idx
:=
1
if
!
filters
.
StartTime
.
IsZero
()
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at >= $%d"
,
idx
))
args
=
append
(
args
,
filters
.
StartTime
)
idx
++
}
if
!
filters
.
EndTime
.
IsZero
()
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at <= $%d"
,
idx
))
args
=
append
(
args
,
filters
.
EndTime
)
idx
++
}
if
filters
.
UserID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"user_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
UserID
)
idx
++
}
if
filters
.
APIKeyID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"api_key_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
APIKeyID
)
idx
++
}
if
filters
.
AccountID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"account_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
AccountID
)
idx
++
}
if
filters
.
GroupID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"group_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
GroupID
)
idx
++
}
if
filters
.
Model
!=
nil
{
model
:=
strings
.
TrimSpace
(
*
filters
.
Model
)
if
model
!=
""
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"model = $%d"
,
idx
))
args
=
append
(
args
,
model
)
idx
++
}
}
if
filters
.
Stream
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"stream = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
Stream
)
idx
++
}
if
filters
.
BillingType
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"billing_type = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
BillingType
)
}
return
strings
.
Join
(
conditions
,
" AND "
),
args
}
// createTaskWithEnt inserts the task via the ent backend and copies the
// database-assigned ID and timestamps back onto the caller's struct.
func (r *usageCleanupRepository) createTaskWithEnt(ctx context.Context, task *service.UsageCleanupTask) error {
	// clientFromContext lets a transaction bound to ctx override r.client.
	client := clientFromContext(ctx, r.client)
	filtersJSON, err := json.Marshal(task.Filters)
	if err != nil {
		return fmt.Errorf("marshal cleanup filters: %w", err)
	}
	created, err := client.UsageCleanupTask.Create().
		SetStatus(task.Status).
		SetFilters(json.RawMessage(filtersJSON)).
		SetCreatedBy(task.CreatedBy).
		SetDeletedRows(task.DeletedRows).
		Save(ctx)
	if err != nil {
		return err
	}
	// Reflect server-side defaults back to the caller.
	task.ID = created.ID
	task.CreatedAt = created.CreatedAt
	task.UpdatedAt = created.UpdatedAt
	return nil
}
// createTaskWithSQL inserts the task via raw SQL and reads back the generated
// ID and timestamps through RETURNING.
func (r *usageCleanupRepository) createTaskWithSQL(ctx context.Context, task *service.UsageCleanupTask) error {
	filtersJSON, err := json.Marshal(task.Filters)
	if err != nil {
		return fmt.Errorf("marshal cleanup filters: %w", err)
	}
	query := `
		INSERT INTO usage_cleanup_tasks (
			status,
			filters,
			created_by,
			deleted_rows
		) VALUES ($1, $2, $3, $4)
		RETURNING id, created_at, updated_at
	`
	if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
		return err
	}
	return nil
}
// listTasksWithEnt is the ent-backed page listing: count first, short-circuit
// on an empty table, then fetch one page ordered newest-first (id DESC as a
// tiebreaker for stable paging).
func (r *usageCleanupRepository) listTasksWithEnt(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
	client := clientFromContext(ctx, r.client)
	query := client.UsageCleanupTask.Query()
	// Clone so the count does not consume the query we page with.
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}
	if total == 0 {
		return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
	}
	rows, err := query.
		Order(dbent.Desc(dbusagecleanuptask.FieldCreatedAt), dbent.Desc(dbusagecleanuptask.FieldID)).
		Offset(params.Offset()).
		Limit(params.Limit()).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}
	tasks := make([]service.UsageCleanupTask, 0, len(rows))
	for _, row := range rows {
		task, err := usageCleanupTaskFromEnt(row)
		if err != nil {
			return nil, nil, err
		}
		tasks = append(tasks, task)
	}
	return tasks, paginationResultFromTotal(int64(total), params), nil
}
// getTaskStatusWithEnt reads a task's status via ent, translating ent's
// not-found error into sql.ErrNoRows so both backends share one sentinel.
func (r *usageCleanupRepository) getTaskStatusWithEnt(ctx context.Context, taskID int64) (string, error) {
	client := clientFromContext(ctx, r.client)
	task, err := client.UsageCleanupTask.Query().
		Where(dbusagecleanuptask.IDEQ(taskID)).
		Only(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return "", sql.ErrNoRows
		}
		return "", err
	}
	return task.Status, nil
}
func
(
r
*
usageCleanupRepository
)
updateTaskProgressWithEnt
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
now
:=
time
.
Now
()
_
,
err
:=
client
.
UsageCleanupTask
.
Update
()
.
Where
(
dbusagecleanuptask
.
IDEQ
(
taskID
))
.
SetDeletedRows
(
deletedRows
)
.
SetUpdatedAt
(
now
)
.
Save
(
ctx
)
return
err
}
// cancelTaskWithEnt is the ent-backed cancel. The StatusIn predicate limits
// the update to pending/running tasks, making the cancel idempotent and safe
// against a concurrent finish; affected==0 means nothing was cancelable.
func (r *usageCleanupRepository) cancelTaskWithEnt(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
	client := clientFromContext(ctx, r.client)
	now := time.Now()
	affected, err := client.UsageCleanupTask.Update().
		Where(
			dbusagecleanuptask.IDEQ(taskID),
			dbusagecleanuptask.StatusIn(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning),
		).
		SetStatus(service.UsageCleanupStatusCanceled).
		SetCanceledBy(canceledBy).
		SetCanceledAt(now).
		SetFinishedAt(now).
		ClearErrorMessage().
		SetUpdatedAt(now).
		Save(ctx)
	if err != nil {
		return false, err
	}
	return affected > 0, nil
}
// markTaskSucceededWithEnt is the ent-backed success transition: final
// deleted-rows total, finished_at, and updated_at in one update.
func (r *usageCleanupRepository) markTaskSucceededWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
	client := clientFromContext(ctx, r.client)
	now := time.Now()
	_, err := client.UsageCleanupTask.Update().
		Where(dbusagecleanuptask.IDEQ(taskID)).
		SetStatus(service.UsageCleanupStatusSucceeded).
		SetDeletedRows(deletedRows).
		SetFinishedAt(now).
		SetUpdatedAt(now).
		Save(ctx)
	return err
}
// markTaskFailedWithEnt is the ent-backed failure transition: partial
// deleted-rows total, the error message, finished_at, and updated_at.
func (r *usageCleanupRepository) markTaskFailedWithEnt(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
	client := clientFromContext(ctx, r.client)
	now := time.Now()
	_, err := client.UsageCleanupTask.Update().
		Where(dbusagecleanuptask.IDEQ(taskID)).
		SetStatus(service.UsageCleanupStatusFailed).
		SetDeletedRows(deletedRows).
		SetErrorMessage(errorMsg).
		SetFinishedAt(now).
		SetUpdatedAt(now).
		Save(ctx)
	return err
}
// usageCleanupTaskFromEnt maps an ent row to the service model. The filters
// JSON document is decoded when present; optional columns are copied only
// when non-nil. On a corrupt filters document the zero task and an error are
// returned.
func usageCleanupTaskFromEnt(row *dbent.UsageCleanupTask) (service.UsageCleanupTask, error) {
	task := service.UsageCleanupTask{
		ID:          row.ID,
		Status:      row.Status,
		CreatedBy:   row.CreatedBy,
		DeletedRows: row.DeletedRows,
		CreatedAt:   row.CreatedAt,
		UpdatedAt:   row.UpdatedAt,
	}
	if len(row.Filters) > 0 {
		if err := json.Unmarshal(row.Filters, &task.Filters); err != nil {
			return service.UsageCleanupTask{}, fmt.Errorf("parse cleanup filters: %w", err)
		}
	}
	if row.ErrorMessage != nil {
		task.ErrorMsg = row.ErrorMessage
	}
	if row.CanceledBy != nil {
		task.CanceledBy = row.CanceledBy
	}
	if row.CanceledAt != nil {
		task.CanceledAt = row.CanceledAt
	}
	if row.StartedAt != nil {
		task.StartedAt = row.StartedAt
	}
	if row.FinishedAt != nil {
		task.FinishedAt = row.FinishedAt
	}
	return task, nil
}
backend/internal/repository/usage_cleanup_repo_ent_test.go
0 → 100644
View file @
0170d19f
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
dbusagecleanuptask
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql
"entgo.io/ent/dialect/sql"
_
"modernc.org/sqlite"
)
// newUsageCleanupEntRepo builds a repository backed by an in-memory SQLite
// database (shared-cache so the *sql.DB and the ent client see the same
// data), registering cleanup for both handles. It returns the repo plus the
// raw ent client for direct fixture access.
func newUsageCleanupEntRepo(t *testing.T) (*usageCleanupRepository, *dbent.Client) {
	t.Helper()
	db, err := sql.Open("sqlite", "file:usage_cleanup?mode=memory&cache=shared")
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })
	_, err = db.Exec("PRAGMA foreign_keys = ON")
	require.NoError(t, err)
	// Wrap the stdlib handle in an ent driver so both backends share one DB.
	drv := entsql.OpenDB(dialect.SQLite, db)
	client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
	t.Cleanup(func() { _ = client.Close() })
	repo := &usageCleanupRepository{client: client, sql: db}
	return repo, client
}
// TestUsageCleanupRepositoryEntCreateAndList verifies that created tasks get
// IDs, that listing returns newest-first (tasks[0].ID > tasks[1].ID), and
// that the filters JSON round-trips through storage.
func TestUsageCleanupRepositoryEntCreateAndList(t *testing.T) {
	repo, _ := newUsageCleanupEntRepo(t)
	start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	task := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusPending,
		Filters:   service.UsageCleanupFilters{StartTime: start, EndTime: end},
		CreatedBy: 9,
	}
	require.NoError(t, repo.CreateTask(context.Background(), task))
	require.NotZero(t, task.ID)
	task2 := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusRunning,
		Filters:   service.UsageCleanupFilters{StartTime: start.Add(-24 * time.Hour), EndTime: end.Add(-24 * time.Hour)},
		CreatedBy: 10,
	}
	require.NoError(t, repo.CreateTask(context.Background(), task2))
	tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
	require.NoError(t, err)
	require.Len(t, tasks, 2)
	require.Equal(t, int64(2), result.Total)
	// Newest-first ordering: the second insert has the larger ID.
	require.Greater(t, tasks[0].ID, tasks[1].ID)
	require.Equal(t, start, tasks[1].Filters.StartTime)
	require.Equal(t, end, tasks[1].Filters.EndTime)
}
// TestUsageCleanupRepositoryEntListEmpty verifies that listing an empty table
// yields an empty slice and a zero total, not an error.
func TestUsageCleanupRepositoryEntListEmpty(t *testing.T) {
	repo, _ := newUsageCleanupEntRepo(t)
	tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
	require.NoError(t, err)
	require.Empty(t, tasks)
	require.Equal(t, int64(0), result.Total)
}
// TestUsageCleanupRepositoryEntGetStatusAndProgress covers status lookup for
// an existing and a missing task (missing maps to sql.ErrNoRows) and verifies
// progress updates are persisted.
func TestUsageCleanupRepositoryEntGetStatusAndProgress(t *testing.T) {
	repo, client := newUsageCleanupEntRepo(t)
	task := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusPending,
		Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
		CreatedBy: 3,
	}
	require.NoError(t, repo.CreateTask(context.Background(), task))
	status, err := repo.GetTaskStatus(context.Background(), task.ID)
	require.NoError(t, err)
	require.Equal(t, service.UsageCleanupStatusPending, status)
	// A non-existent ID surfaces the shared sql.ErrNoRows sentinel.
	_, err = repo.GetTaskStatus(context.Background(), task.ID+99)
	require.ErrorIs(t, err, sql.ErrNoRows)
	require.NoError(t, repo.UpdateTaskProgress(context.Background(), task.ID, 42))
	loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
	require.NoError(t, err)
	require.Equal(t, int64(42), loaded.DeletedRows)
}
// TestUsageCleanupRepositoryEntCancelAndFinish verifies a pending task can be
// canceled (recording who/when and finishing it), and that canceling a task
// already in a terminal state reports ok=false without error.
func TestUsageCleanupRepositoryEntCancelAndFinish(t *testing.T) {
	repo, client := newUsageCleanupEntRepo(t)
	task := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusPending,
		Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
		CreatedBy: 5,
	}
	require.NoError(t, repo.CreateTask(context.Background(), task))
	ok, err := repo.CancelTask(context.Background(), task.ID, 7)
	require.NoError(t, err)
	require.True(t, ok)
	loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
	require.NoError(t, err)
	require.Equal(t, service.UsageCleanupStatusCanceled, loaded.Status)
	require.NotNil(t, loaded.CanceledBy)
	require.NotNil(t, loaded.CanceledAt)
	require.NotNil(t, loaded.FinishedAt)
	// Force the task into a terminal (succeeded) state directly via ent...
	loaded.Status = service.UsageCleanupStatusSucceeded
	_, err = client.UsageCleanupTask.Update().
		Where(dbusagecleanuptask.IDEQ(task.ID)).
		SetStatus(loaded.Status).
		Save(context.Background())
	require.NoError(t, err)
	// ...so a second cancel no longer matches the pending/running guard.
	ok, err = repo.CancelTask(context.Background(), task.ID, 7)
	require.NoError(t, err)
	require.False(t, ok)
}
// TestUsageCleanupRepositoryEntCancelError verifies that CancelTask surfaces
// backend failures: closing the ent client makes the subsequent update fail.
func TestUsageCleanupRepositoryEntCancelError(t *testing.T) {
	repo, client := newUsageCleanupEntRepo(t)
	task := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusPending,
		Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
		CreatedBy: 5,
	}
	require.NoError(t, repo.CreateTask(context.Background(), task))
	// Close the shared client so the cancel's UPDATE has no live connection.
	require.NoError(t, client.Close())
	_, err := repo.CancelTask(context.Background(), task.ID, 7)
	require.Error(t, err)
}
// TestUsageCleanupRepositoryEntMarkResults verifies the two terminal
// transitions: MarkTaskSucceeded stores the final row count and finish time,
// MarkTaskFailed additionally persists the error message.
func TestUsageCleanupRepositoryEntMarkResults(t *testing.T) {
	repo, client := newUsageCleanupEntRepo(t)
	task := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusRunning,
		Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
		CreatedBy: 12,
	}
	require.NoError(t, repo.CreateTask(context.Background(), task))
	require.NoError(t, repo.MarkTaskSucceeded(context.Background(), task.ID, 6))
	loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
	require.NoError(t, err)
	require.Equal(t, service.UsageCleanupStatusSucceeded, loaded.Status)
	require.Equal(t, int64(6), loaded.DeletedRows)
	require.NotNil(t, loaded.FinishedAt)
	task2 := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusRunning,
		Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
		CreatedBy: 12,
	}
	require.NoError(t, repo.CreateTask(context.Background(), task2))
	require.NoError(t, repo.MarkTaskFailed(context.Background(), task2.ID, 4, "boom"))
	loaded2, err := client.UsageCleanupTask.Get(context.Background(), task2.ID)
	require.NoError(t, err)
	require.Equal(t, service.UsageCleanupStatusFailed, loaded2.Status)
	require.Equal(t, "boom", *loaded2.ErrorMessage)
}
// TestUsageCleanupRepositoryEntInvalidStatus verifies that an unrecognized
// status value is rejected at insert time (presumably by an ent field
// validator on the status column — confirmed by the expected error here).
func TestUsageCleanupRepositoryEntInvalidStatus(t *testing.T) {
	repo, _ := newUsageCleanupEntRepo(t)
	task := &service.UsageCleanupTask{
		Status:    "invalid",
		Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
		CreatedBy: 1,
	}
	require.Error(t, repo.CreateTask(context.Background(), task))
}
// TestUsageCleanupRepositoryEntListInvalidFilters verifies that a row whose
// filters column holds malformed JSON makes ListTasks fail. The bad row is
// smuggled in through the raw driver, bypassing ent's JSON marshalling.
func TestUsageCleanupRepositoryEntListInvalidFilters(t *testing.T) {
	repo, client := newUsageCleanupEntRepo(t)
	now := time.Now().UTC()
	driver, ok := client.Driver().(*entsql.Driver)
	require.True(t, ok)
	_, err := driver.DB().ExecContext(context.Background(),
		`INSERT INTO usage_cleanup_tasks (status, filters, created_by, deleted_rows, created_at, updated_at)
         VALUES (?, ?, ?, ?, ?, ?)`,
		service.UsageCleanupStatusPending,
		[]byte("invalid-json"),
		int64(1),
		int64(0),
		now,
		now,
	)
	require.NoError(t, err)
	_, _, err = repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
	require.Error(t, err)
}
// TestUsageCleanupTaskFromEntFull exercises the ent-to-service mapper with
// every optional field populated and checks each pointer field is carried
// over.
func TestUsageCleanupTaskFromEntFull(t *testing.T) {
	start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	errMsg := "failed"
	canceledBy := int64(2)
	canceledAt := start.Add(time.Minute)
	startedAt := start.Add(2 * time.Minute)
	finishedAt := start.Add(3 * time.Minute)
	filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
	filtersJSON, err := json.Marshal(filters)
	require.NoError(t, err)
	task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
		ID:           10,
		Status:       service.UsageCleanupStatusFailed,
		Filters:      filtersJSON,
		CreatedBy:    11,
		DeletedRows:  7,
		ErrorMessage: &errMsg,
		CanceledBy:   &canceledBy,
		CanceledAt:   &canceledAt,
		StartedAt:    &startedAt,
		FinishedAt:   &finishedAt,
		CreatedAt:    start,
		UpdatedAt:    end,
	})
	require.NoError(t, err)
	require.Equal(t, int64(10), task.ID)
	require.Equal(t, service.UsageCleanupStatusFailed, task.Status)
	require.NotNil(t, task.ErrorMsg)
	require.NotNil(t, task.CanceledBy)
	require.NotNil(t, task.CanceledAt)
	require.NotNil(t, task.StartedAt)
	require.NotNil(t, task.FinishedAt)
}
// TestUsageCleanupTaskFromEntInvalidFilters verifies the mapper returns an
// error and a zero task when the stored filters JSON is malformed.
func TestUsageCleanupTaskFromEntInvalidFilters(t *testing.T) {
	task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
		Filters: json.RawMessage("invalid-json"),
	})
	require.Error(t, err)
	require.Empty(t, task)
}
backend/internal/repository/usage_cleanup_repo_test.go
0 → 100644
View file @
0170d19f
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// newSQLMock returns a sqlmock-backed *sql.DB whose expectations are matched
// by regular expression (QueryMatcherRegexp), with cleanup registered on t.
func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
	t.Helper()
	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })
	return db, mock
}
// TestNewUsageCleanupRepository checks the constructor returns a usable
// (non-nil) repository even without an ent client.
func TestNewUsageCleanupRepository(t *testing.T) {
	db, _ := newSQLMock(t)
	repo := NewUsageCleanupRepository(nil, db)
	require.NotNil(t, repo)
}
// TestUsageCleanupRepositoryCreateTask verifies the raw-SQL insert path:
// the INSERT's RETURNING values are copied back onto the task struct.
func TestUsageCleanupRepositoryCreateTask(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	task := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusPending,
		Filters:   service.UsageCleanupFilters{StartTime: start, EndTime: end},
		CreatedBy: 12,
	}
	now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
	// The marshalled filters blob is matched loosely with AnyArg.
	mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
		WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
		WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now))
	err := repo.CreateTask(context.Background(), task)
	require.NoError(t, err)
	require.Equal(t, int64(1), task.ID)
	require.Equal(t, now, task.CreatedAt)
	require.Equal(t, now, task.UpdatedAt)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageCleanupRepositoryCreateTaskNil verifies a nil task is a no-op:
// no SQL is issued (no expectations registered) and no error is returned.
func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	err := repo.CreateTask(context.Background(), nil)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageCleanupRepositoryCreateTaskQueryError verifies a database failure
// during the INSERT is propagated to the caller.
func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	task := &service.UsageCleanupTask{
		Status:    service.UsageCleanupStatusPending,
		Filters:   service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)},
		CreatedBy: 1,
	}
	mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
		WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
		WillReturnError(sql.ErrConnDone)
	err := repo.CreateTask(context.Background(), task)
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageCleanupRepositoryListTasksEmpty verifies that when COUNT(*) is 0
// the listing short-circuits: only the count query runs, and the result is an
// empty slice with a zero total.
func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
	tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
	require.NoError(t, err)
	require.Empty(t, tasks)
	require.Equal(t, int64(0), result.Total)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageCleanupRepositoryListTasks verifies the raw-SQL listing path with
// one populated row: count query, page query with LIMIT/OFFSET args, row
// scanning (including nullable columns), and filters JSON decoding.
func TestUsageCleanupRepositoryListTasks(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(2 * time.Hour)
	filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
	filtersJSON, err := json.Marshal(filters)
	require.NoError(t, err)
	createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)
	updatedAt := createdAt.Add(time.Minute)
	// canceled_by / canceled_at are NULL; started_at / finished_at are set.
	rows := sqlmock.NewRows([]string{
		"id", "status", "filters", "created_by", "deleted_rows", "error_message",
		"canceled_by", "canceled_at",
		"started_at", "finished_at", "created_at", "updated_at",
	}).AddRow(
		int64(1), service.UsageCleanupStatusSucceeded, filtersJSON, int64(2), int64(9), "error",
		nil, nil,
		start, end, createdAt, updatedAt,
	)
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
	mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
		WithArgs(20, 0).
		WillReturnRows(rows)
	tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
	require.NoError(t, err)
	require.Len(t, tasks, 1)
	require.Equal(t, int64(1), tasks[0].ID)
	require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status)
	require.Equal(t, int64(2), tasks[0].CreatedBy)
	require.Equal(t, int64(9), tasks[0].DeletedRows)
	require.NotNil(t, tasks[0].ErrorMsg)
	require.Equal(t, "error", *tasks[0].ErrorMsg)
	require.NotNil(t, tasks[0].StartedAt)
	require.NotNil(t, tasks[0].FinishedAt)
	require.Equal(t, int64(1), result.Total)
	require.NoError(t, mock.ExpectationsWereMet())
}
func
TestUsageCleanupRepositoryListTasksQueryError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"SELECT COUNT
\\
(
\\
*
\\
) FROM usage_cleanup_tasks"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
2
)))
mock
.
ExpectQuery
(
"SELECT id, status, filters, created_by, deleted_rows, error_message"
)
.
WithArgs
(
20
,
0
)
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
_
,
err
:=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
// TestUsageCleanupRepositoryListTasksInvalidFilters ensures that a row whose
// filters column holds malformed JSON causes ListTasks to return an error
// instead of a partially decoded task.
func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	// Fixture row: all columns well-formed except filters, which is "not-json".
	rows := sqlmock.NewRows([]string{
		"id",
		"status",
		"filters",
		"created_by",
		"deleted_rows",
		"error_message",
		"canceled_by",
		"canceled_at",
		"started_at",
		"finished_at",
		"created_at",
		"updated_at",
	}).AddRow(
		int64(1),
		service.UsageCleanupStatusSucceeded,
		[]byte("not-json"),
		int64(2),
		int64(9),
		nil,
		nil,
		nil,
		nil,
		nil,
		time.Now().UTC(),
		time.Now().UTC(),
	)
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
	mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
		WithArgs(20, 0).
		WillReturnRows(rows)
	_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
func
TestUsageCleanupRepositoryClaimNextPendingTaskNone
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
int64
(
1800
),
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
,
"status"
,
"filters"
,
"created_by"
,
"deleted_rows"
,
"error_message"
,
"started_at"
,
"finished_at"
,
"created_at"
,
"updated_at"
,
}))
task
,
err
:=
repo
.
ClaimNextPendingTask
(
context
.
Background
(),
1800
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
task
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
// TestUsageCleanupRepositoryClaimNextPendingTask verifies the successful claim
// path: the claiming query returns exactly one row, whose columns (including
// the JSON-encoded filters) are mapped onto the returned task.
func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	// Fixture filters covering a 24-hour window, marshaled as the row payload.
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
	filtersJSON, err := json.Marshal(filters)
	require.NoError(t, err)
	// One claimed row; value order must match the column header order.
	// finished_at is NULL because the task is still running.
	rows := sqlmock.NewRows([]string{
		"id",
		"status",
		"filters",
		"created_by",
		"deleted_rows",
		"error_message",
		"started_at",
		"finished_at",
		"created_at",
		"updated_at",
	}).AddRow(
		int64(4),
		service.UsageCleanupStatusRunning,
		filtersJSON,
		int64(7),
		int64(0),
		nil,
		start,
		nil,
		start,
		start,
	)
	// The claim is issued as a query (pending -> running transition with a
	// 1800-second stale-running takeover threshold).
	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
		WillReturnRows(rows)
	task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
	require.NoError(t, err)
	require.NotNil(t, task)
	require.Equal(t, int64(4), task.ID)
	require.Equal(t, service.UsageCleanupStatusRunning, task.Status)
	require.Equal(t, int64(7), task.CreatedBy)
	require.NotNil(t, task.StartedAt)
	require.Nil(t, task.ErrorMsg)
	require.NoError(t, mock.ExpectationsWereMet())
}
func
TestUsageCleanupRepositoryClaimNextPendingTaskError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
int64
(
1800
),
service
.
UsageCleanupStatusRunning
)
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
err
:=
repo
.
ClaimNextPendingTask
(
context
.
Background
(),
1800
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
// TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters ensures that a
// claimed row whose filters column is not valid JSON makes
// ClaimNextPendingTask return an error.
func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	// Fixture row: structurally complete but with an "invalid" filters payload.
	rows := sqlmock.NewRows([]string{
		"id",
		"status",
		"filters",
		"created_by",
		"deleted_rows",
		"error_message",
		"started_at",
		"finished_at",
		"created_at",
		"updated_at",
	}).AddRow(
		int64(4),
		service.UsageCleanupStatusRunning,
		[]byte("invalid"),
		int64(7),
		int64(0),
		nil,
		nil,
		nil,
		time.Now().UTC(),
		time.Now().UTC(),
	)
	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
		WillReturnRows(rows)
	_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
func
TestUsageCleanupRepositoryMarkTaskSucceeded
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectExec
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusSucceeded
,
int64
(
12
),
int64
(
9
))
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
err
:=
repo
.
MarkTaskSucceeded
(
context
.
Background
(),
9
,
12
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryMarkTaskFailed
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectExec
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusFailed
,
int64
(
4
),
"boom"
,
int64
(
2
))
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
err
:=
repo
.
MarkTaskFailed
(
context
.
Background
(),
2
,
4
,
"boom"
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryGetTaskStatus
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"SELECT status FROM usage_cleanup_tasks"
)
.
WithArgs
(
int64
(
9
))
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"status"
})
.
AddRow
(
service
.
UsageCleanupStatusPending
))
status
,
err
:=
repo
.
GetTaskStatus
(
context
.
Background
(),
9
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusPending
,
status
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryGetTaskStatusQueryError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"SELECT status FROM usage_cleanup_tasks"
)
.
WithArgs
(
int64
(
9
))
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
err
:=
repo
.
GetTaskStatus
(
context
.
Background
(),
9
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryUpdateTaskProgress
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectExec
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
int64
(
123
),
int64
(
8
))
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
err
:=
repo
.
UpdateTaskProgress
(
context
.
Background
(),
8
,
123
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryCancelTask
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusCanceled
,
int64
(
6
),
int64
(
9
),
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
})
.
AddRow
(
int64
(
6
)))
ok
,
err
:=
repo
.
CancelTask
(
context
.
Background
(),
6
,
9
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
ok
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryCancelTaskNoRows
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusCanceled
,
int64
(
6
),
int64
(
9
),
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
}))
ok
,
err
:=
repo
.
CancelTask
(
context
.
Background
(),
6
,
9
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
ok
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange
(
t
*
testing
.
T
)
{
db
,
_
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
_
,
err
:=
repo
.
DeleteUsageLogsBatch
(
context
.
Background
(),
service
.
UsageCleanupFilters
{},
10
)
require
.
Error
(
t
,
err
)
}
// TestUsageCleanupRepositoryDeleteUsageLogsBatch verifies a successful batch
// delete: the optional user-ID and model filters are appended as query
// arguments (the model value is passed trimmed of surrounding whitespace),
// and the number of returned IDs becomes the deleted count.
func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	userID := int64(3)
	// Deliberately padded: the expectation below asserts it arrives as "gpt-4".
	model := " gpt-4 "
	filters := service.UsageCleanupFilters{
		StartTime: start,
		EndTime:   end,
		UserID:    &userID,
		Model:     &model,
	}
	// Args: time range, user filter, trimmed model, then the batch limit (2).
	mock.ExpectQuery("DELETE FROM usage_logs").
		WithArgs(start, end, userID, "gpt-4", 2).
		WillReturnRows(sqlmock.NewRows([]string{"id"}).
			AddRow(int64(1)).
			AddRow(int64(2)))
	deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2)
	require.NoError(t, err)
	require.Equal(t, int64(2), deleted)
	require.NoError(t, mock.ExpectationsWereMet())
}
func
TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
filters
:=
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
}
mock
.
ExpectQuery
(
"DELETE FROM usage_logs"
)
.
WithArgs
(
start
,
end
,
5
)
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
err
:=
repo
.
DeleteUsageLogsBatch
(
context
.
Background
(),
filters
,
5
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestBuildUsageCleanupWhere
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
userID
:=
int64
(
1
)
apiKeyID
:=
int64
(
2
)
accountID
:=
int64
(
3
)
groupID
:=
int64
(
4
)
model
:=
" gpt-4 "
stream
:=
true
billingType
:=
int8
(
2
)
where
,
args
:=
buildUsageCleanupWhere
(
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
,
UserID
:
&
userID
,
APIKeyID
:
&
apiKeyID
,
AccountID
:
&
accountID
,
GroupID
:
&
groupID
,
Model
:
&
model
,
Stream
:
&
stream
,
BillingType
:
&
billingType
,
})
require
.
Equal
(
t
,
"created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9"
,
where
)
require
.
Equal
(
t
,
[]
any
{
start
,
end
,
userID
,
apiKeyID
,
accountID
,
groupID
,
"gpt-4"
,
stream
,
billingType
},
args
)
}
func
TestBuildUsageCleanupWhereModelEmpty
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
model
:=
" "
where
,
args
:=
buildUsageCleanupWhere
(
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
,
Model
:
&
model
,
})
require
.
Equal
(
t
,
"created_at >= $1 AND created_at <= $2"
,
where
)
require
.
Equal
(
t
,
[]
any
{
start
,
end
},
args
)
}
backend/internal/repository/usage_log_repo.go
View file @
0170d19f
...
...
@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
(
results
[]
TrendDataPoint
,
err
error
)
{
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
...
...
@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
if
billingType
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND billing_type = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
int16
(
*
billingType
))
}
query
+=
" GROUP BY date ORDER BY date ASC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
(
results
[]
ModelStat
,
err
error
)
{
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
ModelStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
...
...
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
if
billingType
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND billing_type = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
int16
(
*
billingType
))
}
query
+=
" GROUP BY model ORDER BY total_tokens DESC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
nil
)
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
models
=
[]
ModelStat
{}
}
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment