Commit 0170d19f authored by song's avatar song
Browse files

merge upstream main

parent 7ade9baa
......@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestGatewayRoutingSuite(t *testing.T) {
......
......@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
......@@ -14,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
......@@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
//
// 参数:
// - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
// 如果未启用 TLS 指纹,直接使用标准请求路径
if !enableTLSFingerprint {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
// TLS 指纹已启用,记录调试日志
targetHost := ""
if req != nil && req.URL != nil {
targetHost = req.URL.Host
}
proxyInfo := "direct"
if proxyURL != "" {
proxyInfo = proxyURL
}
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo)
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取 TLS 指纹 Profile
registry := tlsfingerprint.GlobalRegistry()
profile := registry.GetProfileByAccountID(accountID)
if profile == nil {
// 如果获取不到 profile,回退到普通请求
slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request")
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE)
// 获取或创建带 TLS 指纹的客户端
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
if err != nil {
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
return nil, err
}
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err)
return nil, err
}
slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode)
// 包装响应体,在关闭时自动减少计数并更新时间戳
resp.Body = wrapTrackedBody(resp.Body, func() {
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
})
return resp, nil
}
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
}
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
isolation := s.getIsolationMode()
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
now := time.Now()
nowUnix := now.UnixNano()
// 读锁快速路径
s.mu.RLock()
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.RUnlock()
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
return entry, nil
}
s.mu.RUnlock()
// 写锁慢路径
s.mu.Lock()
if entry, ok := s.clients[cacheKey]; ok {
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.Unlock()
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
return entry, nil
}
slog.Debug("tls_fingerprint_evicting_stale_client",
"account_id", accountID,
"cache_key", cacheKey,
"proxy_changed", entry.proxyKey != proxyKey,
"pool_changed", entry.poolKey != poolKey)
s.removeClientLocked(cacheKey, entry)
}
// 超出缓存上限时尝试淘汰
if enforceLimit && s.maxUpstreamClients() > 0 {
s.evictIdleLocked(now)
if len(s.clients) >= s.maxUpstreamClients() {
if !s.evictOldestIdleLocked() {
s.mu.Unlock()
return nil, errUpstreamClientLimitReached
}
}
}
// 创建带 TLS 指纹的 Transport
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
settings := s.resolvePoolSettings(isolation, accountConcurrency)
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
if err != nil {
s.mu.Unlock()
return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
}
client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.StoreInt64(&entry.inFlight, 1)
}
s.clients[cacheKey] = entry
s.evictIdleLocked(now)
s.evictOverLimitLocked()
s.mu.Unlock()
return entry, nil
}
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
......@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
return transport, nil
}
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URL(nil 表示直连)
// - profile: TLS 指纹配置
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
// - error: 配置错误
//
// 代理类型处理:
// - nil/空: 直连,使用 TLSFingerprintDialer
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
MaxConnsPerHost: settings.maxConnsPerHost,
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
ForceAttemptHTTP2: false,
}
// 根据代理类型选择合适的 TLS 指纹 Dialer
if proxyURL == nil {
// 直连:使用 TLSFingerprintDialer
slog.Debug("tls_fingerprint_transport_direct")
dialer := tlsfingerprint.NewDialer(profile, nil)
transport.DialTLSContext = dialer.DialTLSContext
} else {
scheme := strings.ToLower(proxyURL.Scheme)
switch scheme {
case "socks5", "socks5h":
// SOCKS5 代理:使用 SOCKS5ProxyDialer
slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
transport.DialTLSContext = socks5Dialer.DialTLSContext
case "http", "https":
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
transport.DialTLSContext = httpDialer.DialTLSContext
default:
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
return nil, err
}
}
}
return transport, nil
}
// trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type trackedBody struct {
......
......@@ -13,6 +13,8 @@ import (
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
maskedSessionKeyPrefix = "masked_session:"
maskedSessionTTL = 15 * time.Minute
)
// fingerprintKey generates the Redis key for account fingerprint cache.
......@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
}
// maskedSessionKey generates the Redis key for masked session ID cache.
func maskedSessionKey(accountID int64) string {
return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
}
type identityCache struct {
rdb *redis.Client
}
......@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}
func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
key := maskedSessionKey(accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return "", nil
}
return "", err
}
return val, nil
}
func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
key := maskedSessionKey(accountID)
return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
}
......@@ -2,10 +2,11 @@ package repository
import (
"context"
"fmt"
"net/http"
"net/url"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
......@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
......@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
......@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
Timeout: 120 * time.Second,
})
}
......@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401")
}
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
client := NewOpenAIOAuthClient()
svc, ok := client.(*openaiOAuthService)
require.True(t, ok)
require.Equal(t, openai.TokenURL, svc.tokenURL)
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}
......@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
// Excluded = business-limited errors (quota/concurrency/billing).
// Upstream 429/529 are included in errors view to match SLA calculation.
view := ""
if filter != nil {
view = strings.ToLower(strings.TrimSpace(filter.View))
......@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
case "excluded":
clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
}
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
......
package repository
import (
"crypto/tls"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
opts := &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
......@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
if cfg.Redis.EnableTLS {
opts.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: cfg.Redis.Host,
}
}
return opts
}
......@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
require.Nil(t, opts.TLSConfig)
// Test case with TLS enabled
cfgTLS := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
EnableTLS: true,
},
}
optsTLS := buildRedisOptions(cfgTLS)
require.NotNil(t, optsTLS.TLSConfig)
require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
}
......@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL string // 代理 URL(支持 http/https/socks5)
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
ForceHTTP2 bool // 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
......@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client := req.C().SetTimeout(opts.Timeout)
if opts.ForceHTTP2 {
client = client.EnableForceHTTP2()
}
if opts.Impersonate {
client = client.ImpersonateChrome()
}
......@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
return fmt.Sprintf("%s|%s|%t|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
opts.ForceHTTP2,
)
}
package repository
import (
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func forceHTTPVersion(t *testing.T, client *req.Client) string {
t.Helper()
transport := client.GetTransport()
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
require.True(t, field.IsValid(), "forceHttpVersion field not found")
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
}
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
sharedReqClients = sync.Map{}
base := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
}
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
require.Same(t, first, second)
}
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 3 * time.Second,
}
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
require.True(t, ok)
require.IsType(t, "invalid", loaded)
}
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
......@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
return nil, false, err
}
if len(ids) == 0 {
return []*service.Account{}, true, nil
// 空快照视为缓存未命中,触发数据库回退查询
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
return nil, false, nil
}
keys := make([]string, 0, len(ids))
......
......@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)
......
......@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
......@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
// 使用各账号自己的 idleTimeout,如果没有则用默认值
idleTimeout := c.defaultIdleTimeout
if idleTimeouts != nil {
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
idleTimeout = t
}
}
idleTimeoutSeconds := int(idleTimeout.Seconds())
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}
......
package repository
import (
"context"
"fmt"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error {
if client == nil {
return fmt.Errorf("nil ent client")
}
requiredByPlatform := map[string]int{
service.PlatformAnthropic: 1,
service.PlatformOpenAI: 1,
service.PlatformGemini: 1,
service.PlatformAntigravity: 2,
}
for platform, minCount := range requiredByPlatform {
count, err := client.Group.Query().
Where(group.PlatformEQ(platform), group.DeletedAtIsNil()).
Count(ctx)
if err != nil {
return fmt.Errorf("count groups for platform %s: %w", platform, err)
}
if platform == service.PlatformAntigravity {
if count < minCount {
for i := count; i < minCount; i++ {
name := fmt.Sprintf("%s-default-%d", platform, i+1)
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
return err
}
}
}
continue
}
// Non-antigravity platforms: ensure <platform>-default exists.
name := platform + "-default"
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
return err
}
}
return nil
}
func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error {
exists, err := client.Group.Query().
Where(group.NameEQ(name), group.DeletedAtIsNil()).
Exist(ctx)
if err != nil {
return fmt.Errorf("check group exists %s: %w", name, err)
}
if exists {
return nil
}
_, err = client.Group.Create().
SetName(name).
SetDescription("Auto-created default group").
SetPlatform(platform).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1.0).
SetIsExclusive(false).
Save(ctx)
if err != nil {
if dbent.IsConstraintError(err) {
// Concurrent server startups may race on creation; treat as success.
return nil
}
return fmt.Errorf("create default group %s: %w", name, err)
}
return nil
}
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
client := tx.Client()
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
assertGroupExists := func(name string) {
exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx)
require.NoError(t, err)
require.True(t, exists, "expected group %s to exist", name)
}
assertGroupExists(service.PlatformAnthropic + "-default")
assertGroupExists(service.PlatformOpenAI + "-default")
assertGroupExists(service.PlatformGemini + "-default")
assertGroupExists(service.PlatformAntigravity + "-default-1")
assertGroupExists(service.PlatformAntigravity + "-default-2")
}
func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
client := tx.Client()
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
// Create and then soft-delete an anthropic default group.
g, err := client.Group.Create().
SetName(service.PlatformAnthropic + "-default").
SetPlatform(service.PlatformAnthropic).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1.0).
SetIsExclusive(false).
Save(seedCtx)
require.NoError(t, err)
_, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx)
require.NoError(t, err)
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
// New active one should exist.
count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
client := tx.Client()
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx)
require.NoError(t, err)
require.GreaterOrEqual(t, count, 2)
}
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
totpSetupKeyPrefix = "totp:setup:"
totpLoginKeyPrefix = "totp:login:"
totpAttemptsKeyPrefix = "totp:attempts:"
totpAttemptsTTL = 15 * time.Minute
)
// TotpCache implements service.TotpCache using Redis
type TotpCache struct {
rdb *redis.Client
}
// NewTotpCache creates a new TOTP cache
func NewTotpCache(rdb *redis.Client) service.TotpCache {
return &TotpCache{rdb: rdb}
}
// GetSetupSession retrieves a TOTP setup session
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get setup session: %w", err)
}
var session service.TotpSetupSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal setup session: %w", err)
}
return &session, nil
}
// SetSetupSession stores a TOTP setup session
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal setup session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set setup session: %w", err)
}
return nil
}
// DeleteSetupSession deletes a TOTP setup session
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
// GetLoginSession retrieves a TOTP login session
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
key := totpLoginKeyPrefix + tempToken
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get login session: %w", err)
}
var session service.TotpLoginSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal login session: %w", err)
}
return &session, nil
}
// SetLoginSession stores a TOTP login session
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
key := totpLoginKeyPrefix + tempToken
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal login session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set login session: %w", err)
}
return nil
}
// DeleteLoginSession deletes a TOTP login session
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
key := totpLoginKeyPrefix + tempToken
return c.rdb.Del(ctx, key).Err()
}
// IncrementVerifyAttempts increments the verify attempt counter
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
// Use pipeline for atomic increment and set TTL
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, totpAttemptsTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("increment verify attempts: %w", err)
}
count, err := incrCmd.Result()
if err != nil {
return 0, fmt.Errorf("get increment result: %w", err)
}
return int(count), nil
}
// GetVerifyAttempts gets the current verify attempt count
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
count, err := c.rdb.Get(ctx, key).Int()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, fmt.Errorf("get verify attempts: %w", err)
}
return count, nil
}
// ClearVerifyAttempts clears the verify attempt counter
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageCleanupRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUsageCleanupRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageCleanupRepository {
return newUsageCleanupRepositoryWithSQL(client, sqlDB)
}
func newUsageCleanupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageCleanupRepository {
return &usageCleanupRepository{client: client, sql: sqlq}
}
func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
if r.client != nil {
return r.createTaskWithEnt(ctx, task)
}
return r.createTaskWithSQL(ctx, task)
}
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
if r.client != nil {
return r.listTasksWithEnt(ctx, params)
}
var total int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
return nil, nil, err
}
if total == 0 {
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
}
query := `
SELECT id, status, filters, created_by, deleted_rows, error_message,
canceled_by, canceled_at,
started_at, finished_at, created_at, updated_at
FROM usage_cleanup_tasks
ORDER BY created_at DESC, id DESC
LIMIT $1 OFFSET $2
`
rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
if err != nil {
return nil, nil, err
}
defer func() { _ = rows.Close() }()
tasks := make([]service.UsageCleanupTask, 0)
for rows.Next() {
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var canceledBy sql.NullInt64
var canceledAt sql.NullTime
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := rows.Scan(
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&canceledBy,
&canceledAt,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
return nil, nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if canceledBy.Valid {
v := canceledBy.Int64
task.CanceledBy = &v
}
if canceledAt.Valid {
task.CanceledAt = &canceledAt.Time
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
tasks = append(tasks, task)
}
if err := rows.Err(); err != nil {
return nil, nil, err
}
return tasks, paginationResultFromTotal(total, params), nil
}
func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
if staleRunningAfterSeconds <= 0 {
staleRunningAfterSeconds = 1800
}
query := `
WITH next AS (
SELECT id
FROM usage_cleanup_tasks
WHERE status = $1
OR (
status = $2
AND started_at IS NOT NULL
AND started_at < NOW() - ($3 * interval '1 second')
)
ORDER BY created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE usage_cleanup_tasks AS tasks
SET status = $4,
started_at = NOW(),
finished_at = NULL,
error_message = NULL,
updated_at = NOW()
FROM next
WHERE tasks.id = next.id
RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message,
tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at
`
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
staleRunningAfterSeconds,
service.UsageCleanupStatusRunning,
},
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
return &task, nil
}
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
if r.client != nil {
return r.getTaskStatusWithEnt(ctx, taskID)
}
var status string
if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
return "", err
}
return status, nil
}
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
if r.client != nil {
return r.updateTaskProgressWithEnt(ctx, taskID, deletedRows)
}
query := `
UPDATE usage_cleanup_tasks
SET deleted_rows = $1,
updated_at = NOW()
WHERE id = $2
`
_, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
if r.client != nil {
return r.cancelTaskWithEnt(ctx, taskID, canceledBy)
}
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
canceled_by = $3,
canceled_at = NOW(),
finished_at = NOW(),
error_message = NULL,
updated_at = NOW()
WHERE id = $2
AND status IN ($4, $5)
RETURNING id
`
var id int64
err := scanSingleRow(ctx, r.sql, query, []any{
service.UsageCleanupStatusCanceled,
taskID,
canceledBy,
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
}, &id)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
if r.client != nil {
return r.markTaskSucceededWithEnt(ctx, taskID, deletedRows)
}
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $3
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
if r.client != nil {
return r.markTaskFailedWithEnt(ctx, taskID, deletedRows, errorMsg)
}
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
error_message = $3,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $4
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
return err
}
func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
return 0, fmt.Errorf("cleanup filters missing time range")
}
whereClause, args := buildUsageCleanupWhere(filters)
if whereClause == "" {
return 0, fmt.Errorf("cleanup filters missing time range")
}
args = append(args, limit)
query := fmt.Sprintf(`
WITH target AS (
SELECT id
FROM usage_logs
WHERE %s
ORDER BY created_at ASC, id ASC
LIMIT $%d
)
DELETE FROM usage_logs
WHERE id IN (SELECT id FROM target)
RETURNING id
`, whereClause, len(args))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
var deleted int64
for rows.Next() {
deleted++
}
if err := rows.Err(); err != nil {
return 0, err
}
return deleted, nil
}
func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
idx := 1
if !filters.StartTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
args = append(args, filters.StartTime)
idx++
}
if !filters.EndTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
args = append(args, filters.EndTime)
idx++
}
if filters.UserID != nil {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
args = append(args, *filters.UserID)
idx++
}
if filters.APIKeyID != nil {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
args = append(args, *filters.APIKeyID)
idx++
}
if filters.AccountID != nil {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
args = append(args, *filters.AccountID)
idx++
}
if filters.GroupID != nil {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
args = append(args, *filters.GroupID)
idx++
}
if filters.Model != nil {
model := strings.TrimSpace(*filters.Model)
if model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
args = append(args, model)
idx++
}
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream)
idx++
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
args = append(args, *filters.BillingType)
}
return strings.Join(conditions, " AND "), args
}
func (r *usageCleanupRepository) createTaskWithEnt(ctx context.Context, task *service.UsageCleanupTask) error {
client := clientFromContext(ctx, r.client)
filtersJSON, err := json.Marshal(task.Filters)
if err != nil {
return fmt.Errorf("marshal cleanup filters: %w", err)
}
created, err := client.UsageCleanupTask.
Create().
SetStatus(task.Status).
SetFilters(json.RawMessage(filtersJSON)).
SetCreatedBy(task.CreatedBy).
SetDeletedRows(task.DeletedRows).
Save(ctx)
if err != nil {
return err
}
task.ID = created.ID
task.CreatedAt = created.CreatedAt
task.UpdatedAt = created.UpdatedAt
return nil
}
func (r *usageCleanupRepository) createTaskWithSQL(ctx context.Context, task *service.UsageCleanupTask) error {
filtersJSON, err := json.Marshal(task.Filters)
if err != nil {
return fmt.Errorf("marshal cleanup filters: %w", err)
}
query := `
INSERT INTO usage_cleanup_tasks (
status,
filters,
created_by,
deleted_rows
) VALUES ($1, $2, $3, $4)
RETURNING id, created_at, updated_at
`
if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
return err
}
return nil
}
func (r *usageCleanupRepository) listTasksWithEnt(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
query := client.UsageCleanupTask.Query()
total, err := query.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
if total == 0 {
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
}
rows, err := query.
Order(dbent.Desc(dbusagecleanuptask.FieldCreatedAt), dbent.Desc(dbusagecleanuptask.FieldID)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)
if err != nil {
return nil, nil, err
}
tasks := make([]service.UsageCleanupTask, 0, len(rows))
for _, row := range rows {
task, err := usageCleanupTaskFromEnt(row)
if err != nil {
return nil, nil, err
}
tasks = append(tasks, task)
}
return tasks, paginationResultFromTotal(int64(total), params), nil
}
func (r *usageCleanupRepository) getTaskStatusWithEnt(ctx context.Context, taskID int64) (string, error) {
client := clientFromContext(ctx, r.client)
task, err := client.UsageCleanupTask.Query().
Where(dbusagecleanuptask.IDEQ(taskID)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return "", sql.ErrNoRows
}
return "", err
}
return task.Status, nil
}
func (r *usageCleanupRepository) updateTaskProgressWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
client := clientFromContext(ctx, r.client)
now := time.Now()
_, err := client.UsageCleanupTask.Update().
Where(dbusagecleanuptask.IDEQ(taskID)).
SetDeletedRows(deletedRows).
SetUpdatedAt(now).
Save(ctx)
return err
}
func (r *usageCleanupRepository) cancelTaskWithEnt(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
client := clientFromContext(ctx, r.client)
now := time.Now()
affected, err := client.UsageCleanupTask.Update().
Where(
dbusagecleanuptask.IDEQ(taskID),
dbusagecleanuptask.StatusIn(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning),
).
SetStatus(service.UsageCleanupStatusCanceled).
SetCanceledBy(canceledBy).
SetCanceledAt(now).
SetFinishedAt(now).
ClearErrorMessage().
SetUpdatedAt(now).
Save(ctx)
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *usageCleanupRepository) markTaskSucceededWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
client := clientFromContext(ctx, r.client)
now := time.Now()
_, err := client.UsageCleanupTask.Update().
Where(dbusagecleanuptask.IDEQ(taskID)).
SetStatus(service.UsageCleanupStatusSucceeded).
SetDeletedRows(deletedRows).
SetFinishedAt(now).
SetUpdatedAt(now).
Save(ctx)
return err
}
func (r *usageCleanupRepository) markTaskFailedWithEnt(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
client := clientFromContext(ctx, r.client)
now := time.Now()
_, err := client.UsageCleanupTask.Update().
Where(dbusagecleanuptask.IDEQ(taskID)).
SetStatus(service.UsageCleanupStatusFailed).
SetDeletedRows(deletedRows).
SetErrorMessage(errorMsg).
SetFinishedAt(now).
SetUpdatedAt(now).
Save(ctx)
return err
}
func usageCleanupTaskFromEnt(row *dbent.UsageCleanupTask) (service.UsageCleanupTask, error) {
task := service.UsageCleanupTask{
ID: row.ID,
Status: row.Status,
CreatedBy: row.CreatedBy,
DeletedRows: row.DeletedRows,
CreatedAt: row.CreatedAt,
UpdatedAt: row.UpdatedAt,
}
if len(row.Filters) > 0 {
if err := json.Unmarshal(row.Filters, &task.Filters); err != nil {
return service.UsageCleanupTask{}, fmt.Errorf("parse cleanup filters: %w", err)
}
}
if row.ErrorMessage != nil {
task.ErrorMsg = row.ErrorMessage
}
if row.CanceledBy != nil {
task.CanceledBy = row.CanceledBy
}
if row.CanceledAt != nil {
task.CanceledAt = row.CanceledAt
}
if row.StartedAt != nil {
task.StartedAt = row.StartedAt
}
if row.FinishedAt != nil {
task.FinishedAt = row.FinishedAt
}
return task, nil
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newUsageCleanupEntRepo(t *testing.T) (*usageCleanupRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:usage_cleanup?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := &usageCleanupRepository{client: client, sql: db}
return repo, client
}
func TestUsageCleanupRepositoryEntCreateAndList(t *testing.T) {
repo, _ := newUsageCleanupEntRepo(t)
start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
CreatedBy: 9,
}
require.NoError(t, repo.CreateTask(context.Background(), task))
require.NotZero(t, task.ID)
task2 := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusRunning,
Filters: service.UsageCleanupFilters{StartTime: start.Add(-24 * time.Hour), EndTime: end.Add(-24 * time.Hour)},
CreatedBy: 10,
}
require.NoError(t, repo.CreateTask(context.Background(), task2))
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
require.NoError(t, err)
require.Len(t, tasks, 2)
require.Equal(t, int64(2), result.Total)
require.Greater(t, tasks[0].ID, tasks[1].ID)
require.Equal(t, start, tasks[1].Filters.StartTime)
require.Equal(t, end, tasks[1].Filters.EndTime)
}
func TestUsageCleanupRepositoryEntListEmpty(t *testing.T) {
repo, _ := newUsageCleanupEntRepo(t)
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
require.NoError(t, err)
require.Empty(t, tasks)
require.Equal(t, int64(0), result.Total)
}
func TestUsageCleanupRepositoryEntGetStatusAndProgress(t *testing.T) {
repo, client := newUsageCleanupEntRepo(t)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
CreatedBy: 3,
}
require.NoError(t, repo.CreateTask(context.Background(), task))
status, err := repo.GetTaskStatus(context.Background(), task.ID)
require.NoError(t, err)
require.Equal(t, service.UsageCleanupStatusPending, status)
_, err = repo.GetTaskStatus(context.Background(), task.ID+99)
require.ErrorIs(t, err, sql.ErrNoRows)
require.NoError(t, repo.UpdateTaskProgress(context.Background(), task.ID, 42))
loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
require.NoError(t, err)
require.Equal(t, int64(42), loaded.DeletedRows)
}
func TestUsageCleanupRepositoryEntCancelAndFinish(t *testing.T) {
repo, client := newUsageCleanupEntRepo(t)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
CreatedBy: 5,
}
require.NoError(t, repo.CreateTask(context.Background(), task))
ok, err := repo.CancelTask(context.Background(), task.ID, 7)
require.NoError(t, err)
require.True(t, ok)
loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
require.NoError(t, err)
require.Equal(t, service.UsageCleanupStatusCanceled, loaded.Status)
require.NotNil(t, loaded.CanceledBy)
require.NotNil(t, loaded.CanceledAt)
require.NotNil(t, loaded.FinishedAt)
loaded.Status = service.UsageCleanupStatusSucceeded
_, err = client.UsageCleanupTask.Update().Where(dbusagecleanuptask.IDEQ(task.ID)).SetStatus(loaded.Status).Save(context.Background())
require.NoError(t, err)
ok, err = repo.CancelTask(context.Background(), task.ID, 7)
require.NoError(t, err)
require.False(t, ok)
}
func TestUsageCleanupRepositoryEntCancelError(t *testing.T) {
repo, client := newUsageCleanupEntRepo(t)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
CreatedBy: 5,
}
require.NoError(t, repo.CreateTask(context.Background(), task))
require.NoError(t, client.Close())
_, err := repo.CancelTask(context.Background(), task.ID, 7)
require.Error(t, err)
}
func TestUsageCleanupRepositoryEntMarkResults(t *testing.T) {
repo, client := newUsageCleanupEntRepo(t)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusRunning,
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
CreatedBy: 12,
}
require.NoError(t, repo.CreateTask(context.Background(), task))
require.NoError(t, repo.MarkTaskSucceeded(context.Background(), task.ID, 6))
loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
require.NoError(t, err)
require.Equal(t, service.UsageCleanupStatusSucceeded, loaded.Status)
require.Equal(t, int64(6), loaded.DeletedRows)
require.NotNil(t, loaded.FinishedAt)
task2 := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusRunning,
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
CreatedBy: 12,
}
require.NoError(t, repo.CreateTask(context.Background(), task2))
require.NoError(t, repo.MarkTaskFailed(context.Background(), task2.ID, 4, "boom"))
loaded2, err := client.UsageCleanupTask.Get(context.Background(), task2.ID)
require.NoError(t, err)
require.Equal(t, service.UsageCleanupStatusFailed, loaded2.Status)
require.Equal(t, "boom", *loaded2.ErrorMessage)
}
func TestUsageCleanupRepositoryEntInvalidStatus(t *testing.T) {
repo, _ := newUsageCleanupEntRepo(t)
task := &service.UsageCleanupTask{
Status: "invalid",
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
CreatedBy: 1,
}
require.Error(t, repo.CreateTask(context.Background(), task))
}
func TestUsageCleanupRepositoryEntListInvalidFilters(t *testing.T) {
repo, client := newUsageCleanupEntRepo(t)
now := time.Now().UTC()
driver, ok := client.Driver().(*entsql.Driver)
require.True(t, ok)
_, err := driver.DB().ExecContext(
context.Background(),
`INSERT INTO usage_cleanup_tasks (status, filters, created_by, deleted_rows, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)`,
service.UsageCleanupStatusPending,
[]byte("invalid-json"),
int64(1),
int64(0),
now,
now,
)
require.NoError(t, err)
_, _, err = repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
require.Error(t, err)
}
func TestUsageCleanupTaskFromEntFull(t *testing.T) {
start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
errMsg := "failed"
canceledBy := int64(2)
canceledAt := start.Add(time.Minute)
startedAt := start.Add(2 * time.Minute)
finishedAt := start.Add(3 * time.Minute)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
filtersJSON, err := json.Marshal(filters)
require.NoError(t, err)
task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
ID: 10,
Status: service.UsageCleanupStatusFailed,
Filters: filtersJSON,
CreatedBy: 11,
DeletedRows: 7,
ErrorMessage: &errMsg,
CanceledBy: &canceledBy,
CanceledAt: &canceledAt,
StartedAt: &startedAt,
FinishedAt: &finishedAt,
CreatedAt: start,
UpdatedAt: end,
})
require.NoError(t, err)
require.Equal(t, int64(10), task.ID)
require.Equal(t, service.UsageCleanupStatusFailed, task.Status)
require.NotNil(t, task.ErrorMsg)
require.NotNil(t, task.CanceledBy)
require.NotNil(t, task.CanceledAt)
require.NotNil(t, task.StartedAt)
require.NotNil(t, task.FinishedAt)
}
func TestUsageCleanupTaskFromEntInvalidFilters(t *testing.T) {
task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
Filters: json.RawMessage("invalid-json"),
})
require.Error(t, err)
require.Empty(t, task)
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
t.Helper()
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
return db, mock
}
func TestNewUsageCleanupRepository(t *testing.T) {
db, _ := newSQLMock(t)
repo := NewUsageCleanupRepository(nil, db)
require.NotNil(t, repo)
}
func TestUsageCleanupRepositoryCreateTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
CreatedBy: 12,
}
now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now))
err := repo.CreateTask(context.Background(), task)
require.NoError(t, err)
require.Equal(t, int64(1), task.ID)
require.Equal(t, now, task.CreatedAt)
require.Equal(t, now, task.UpdatedAt)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
err := repo.CreateTask(context.Background(), nil)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)},
CreatedBy: 1,
}
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
WillReturnError(sql.ErrConnDone)
err := repo.CreateTask(context.Background(), task)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Empty(t, tasks)
require.Equal(t, int64(0), result.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasks(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(2 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
filtersJSON, err := json.Marshal(filters)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)
updatedAt := createdAt.Add(time.Minute)
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"canceled_by", "canceled_at",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(1),
service.UsageCleanupStatusSucceeded,
filtersJSON,
int64(2),
int64(9),
"error",
nil,
nil,
start,
end,
createdAt,
updatedAt,
)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
WithArgs(20, 0).
WillReturnRows(rows)
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, tasks, 1)
require.Equal(t, int64(1), tasks[0].ID)
require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status)
require.Equal(t, int64(2), tasks[0].CreatedBy)
require.Equal(t, int64(9), tasks[0].DeletedRows)
require.NotNil(t, tasks[0].ErrorMsg)
require.Equal(t, "error", *tasks[0].ErrorMsg)
require.NotNil(t, tasks[0].StartedAt)
require.NotNil(t, tasks[0].FinishedAt)
require.Equal(t, int64(1), result.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasksQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(2)))
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
WithArgs(20, 0).
WillReturnError(sql.ErrConnDone)
_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"canceled_by", "canceled_at",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(1),
service.UsageCleanupStatusSucceeded,
[]byte("not-json"),
int64(2),
int64(9),
nil,
nil,
nil,
nil,
nil,
time.Now().UTC(),
time.Now().UTC(),
)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
WithArgs(20, 0).
WillReturnRows(rows)
_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}))
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.NoError(t, err)
require.Nil(t, task)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
filtersJSON, err := json.Marshal(filters)
require.NoError(t, err)
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(4),
service.UsageCleanupStatusRunning,
filtersJSON,
int64(7),
int64(0),
nil,
start,
nil,
start,
start,
)
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(rows)
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.NoError(t, err)
require.NotNil(t, task)
require.Equal(t, int64(4), task.ID)
require.Equal(t, service.UsageCleanupStatusRunning, task.Status)
require.Equal(t, int64(7), task.CreatedBy)
require.NotNil(t, task.StartedAt)
require.Nil(t, task.ErrorMsg)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnError(sql.ErrConnDone)
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(4),
service.UsageCleanupStatusRunning,
[]byte("invalid"),
int64(7),
int64(0),
nil,
nil,
nil,
time.Now().UTC(),
time.Now().UTC(),
)
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(rows)
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.MarkTaskSucceeded(context.Background(), 9, 12)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom")
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
WithArgs(int64(9)).
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending))
status, err := repo.GetTaskStatus(context.Background(), 9)
require.NoError(t, err)
require.Equal(t, service.UsageCleanupStatusPending, status)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryGetTaskStatusQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
WithArgs(int64(9)).
WillReturnError(sql.ErrConnDone)
_, err := repo.GetTaskStatus(context.Background(), 9)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(int64(123), int64(8)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.UpdateTaskProgress(context.Background(), 8, 123)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCancelTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6)))
ok, err := repo.CancelTask(context.Background(), 6, 9)
require.NoError(t, err)
require.True(t, ok)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCancelTaskNoRows(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
WillReturnRows(sqlmock.NewRows([]string{"id"}))
ok, err := repo.CancelTask(context.Background(), 6, 9)
require.NoError(t, err)
require.False(t, ok)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) {
db, _ := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
_, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10)
require.Error(t, err)
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(3)
model := " gpt-4 "
filters := service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
Model: &model,
}
mock.ExpectQuery("DELETE FROM usage_logs").
WithArgs(start, end, userID, "gpt-4", 2).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2)))
deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2)
require.NoError(t, err)
require.Equal(t, int64(2), deleted)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
mock.ExpectQuery("DELETE FROM usage_logs").
WithArgs(start, end, 5).
WillReturnError(sql.ErrConnDone)
_, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildUsageCleanupWhere(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(1)
apiKeyID := int64(2)
accountID := int64(3)
groupID := int64(4)
model := " gpt-4 "
stream := true
billingType := int8(2)
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
APIKeyID: &apiKeyID,
AccountID: &accountID,
GroupID: &groupID,
Model: &model,
Stream: &stream,
BillingType: &billingType,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where)
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
}
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
model := " "
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
Model: &model,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2", where)
require.Equal(t, []any{start, end}, args)
}
......@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
......@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
......@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
......@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
......@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
if err != nil {
models = []ModelStat{}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment