Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2fe8932c
Unverified
Commit
2fe8932c
authored
Feb 03, 2026
by
Call White
Committed by
GitHub
Feb 03, 2026
Browse files
Merge pull request #3 from cyhhao/main
merge to main
parents
2f2e76f9
adb77af1
Changes
267
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/req_client_pool.go
View file @
2fe8932c
...
...
@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL
string
// 代理 URL(支持 http/https/socks5)
Timeout
time
.
Duration
// 请求超时时间
Impersonate
bool
// 是否模拟 Chrome 浏览器指纹
ForceHTTP2
bool
// 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
...
...
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client
:=
req
.
C
()
.
SetTimeout
(
opts
.
Timeout
)
if
opts
.
ForceHTTP2
{
client
=
client
.
EnableForceHTTP2
()
}
if
opts
.
Impersonate
{
client
=
client
.
ImpersonateChrome
()
}
...
...
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func
buildReqClientKey
(
opts
reqClientOptions
)
string
{
return
fmt
.
Sprintf
(
"%s|%s|%t"
,
return
fmt
.
Sprintf
(
"%s|%s|%t
|%t
"
,
strings
.
TrimSpace
(
opts
.
ProxyURL
),
opts
.
Timeout
.
String
(),
opts
.
Impersonate
,
opts
.
ForceHTTP2
,
)
}
backend/internal/repository/req_client_pool_test.go
0 → 100644
View file @
2fe8932c
package
repository
import
(
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func
forceHTTPVersion
(
t
*
testing
.
T
,
client
*
req
.
Client
)
string
{
t
.
Helper
()
transport
:=
client
.
GetTransport
()
field
:=
reflect
.
ValueOf
(
transport
)
.
Elem
()
.
FieldByName
(
"forceHttpVersion"
)
require
.
True
(
t
,
field
.
IsValid
(),
"forceHttpVersion field not found"
)
require
.
True
(
t
,
field
.
CanAddr
(),
"forceHttpVersion field not addressable"
)
return
reflect
.
NewAt
(
field
.
Type
(),
unsafe
.
Pointer
(
field
.
UnsafeAddr
()))
.
Elem
()
.
String
()
}
func
TestGetSharedReqClient_ForceHTTP2SeparatesCache
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
base
:=
reqClientOptions
{
ProxyURL
:
"http://proxy.local:8080"
,
Timeout
:
time
.
Second
,
}
clientDefault
:=
getSharedReqClient
(
base
)
force
:=
base
force
.
ForceHTTP2
=
true
clientForce
:=
getSharedReqClient
(
force
)
require
.
NotSame
(
t
,
clientDefault
,
clientForce
)
require
.
NotEqual
(
t
,
buildReqClientKey
(
base
),
buildReqClientKey
(
force
))
}
func
TestGetSharedReqClient_ReuseCachedClient
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
"http://proxy.local:8080"
,
Timeout
:
2
*
time
.
Second
,
}
first
:=
getSharedReqClient
(
opts
)
second
:=
getSharedReqClient
(
opts
)
require
.
Same
(
t
,
first
,
second
)
}
func
TestGetSharedReqClient_IgnoresNonClientCache
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
" http://proxy.local:8080 "
,
Timeout
:
3
*
time
.
Second
,
}
key
:=
buildReqClientKey
(
opts
)
sharedReqClients
.
Store
(
key
,
"invalid"
)
client
:=
getSharedReqClient
(
opts
)
require
.
NotNil
(
t
,
client
)
loaded
,
ok
:=
sharedReqClients
.
Load
(
key
)
require
.
True
(
t
,
ok
)
require
.
IsType
(
t
,
"invalid"
,
loaded
)
}
func
TestGetSharedReqClient_ImpersonateAndProxy
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
" http://proxy.local:8080 "
,
Timeout
:
4
*
time
.
Second
,
Impersonate
:
true
,
}
client
:=
getSharedReqClient
(
opts
)
require
.
NotNil
(
t
,
client
)
require
.
Equal
(
t
,
"http://proxy.local:8080|4s|true|false"
,
buildReqClientKey
(
opts
))
}
func
TestCreateOpenAIReqClient_Timeout120Seconds
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createOpenAIReqClient
(
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
120
*
time
.
Second
,
client
.
GetClient
()
.
Timeout
)
}
func
TestCreateGeminiReqClient_ForceHTTP2Disabled
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createGeminiReqClient
(
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
""
,
forceHTTPVersion
(
t
,
client
))
}
backend/internal/repository/scheduler_cache.go
View file @
2fe8932c
...
...
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
return
nil
,
false
,
err
}
if
len
(
ids
)
==
0
{
return
[]
*
service
.
Account
{},
true
,
nil
// 空快照视为缓存未命中,触发数据库回退查询
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
return
nil
,
false
,
nil
}
keys
:=
make
([]
string
,
0
,
len
(
ids
))
...
...
backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
View file @
2fe8932c
...
...
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_
,
_
=
integrationDB
.
ExecContext
(
ctx
,
"TRUNCATE scheduler_outbox"
)
accountRepo
:=
newAccountRepositoryWithSQL
(
client
,
integrationDB
)
accountRepo
:=
newAccountRepositoryWithSQL
(
client
,
integrationDB
,
nil
)
outboxRepo
:=
NewSchedulerOutboxRepository
(
integrationDB
)
cache
:=
NewSchedulerCache
(
rdb
)
...
...
backend/internal/repository/session_limit_cache.go
View file @
2fe8932c
...
...
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
idleTimeouts
map
[
int64
]
time
.
Duration
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
int
),
nil
}
...
...
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
// 使用 pipeline 批量执行
pipe
:=
c
.
rdb
.
Pipeline
()
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
cmds
:=
make
(
map
[
int64
]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
key
:=
sessionLimitKey
(
accountID
)
// 使用各账号自己的 idleTimeout,如果没有则用默认值
idleTimeout
:=
c
.
defaultIdleTimeout
if
idleTimeouts
!=
nil
{
if
t
,
ok
:=
idleTimeouts
[
accountID
];
ok
&&
t
>
0
{
idleTimeout
=
t
}
}
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
cmds
[
accountID
]
=
getActiveSessionCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
idleTimeoutSeconds
)
}
...
...
backend/internal/repository/simple_mode_default_groups.go
0 → 100644
View file @
2fe8932c
package
repository
import
(
"context"
"fmt"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
ensureSimpleModeDefaultGroups
(
ctx
context
.
Context
,
client
*
dbent
.
Client
)
error
{
if
client
==
nil
{
return
fmt
.
Errorf
(
"nil ent client"
)
}
requiredByPlatform
:=
map
[
string
]
int
{
service
.
PlatformAnthropic
:
1
,
service
.
PlatformOpenAI
:
1
,
service
.
PlatformGemini
:
1
,
service
.
PlatformAntigravity
:
2
,
}
for
platform
,
minCount
:=
range
requiredByPlatform
{
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
platform
),
group
.
DeletedAtIsNil
())
.
Count
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"count groups for platform %s: %w"
,
platform
,
err
)
}
if
platform
==
service
.
PlatformAntigravity
{
if
count
<
minCount
{
for
i
:=
count
;
i
<
minCount
;
i
++
{
name
:=
fmt
.
Sprintf
(
"%s-default-%d"
,
platform
,
i
+
1
)
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
}
continue
}
// Non-antigravity platforms: ensure <platform>-default exists.
name
:=
platform
+
"-default"
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
return
nil
}
func
createGroupIfNotExists
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
name
,
platform
string
)
error
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check group exists %s: %w"
,
name
,
err
)
}
if
exists
{
return
nil
}
_
,
err
=
client
.
Group
.
Create
()
.
SetName
(
name
)
.
SetDescription
(
"Auto-created default group"
)
.
SetPlatform
(
platform
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsConstraintError
(
err
)
{
// Concurrent server startups may race on creation; treat as success.
return
nil
}
return
fmt
.
Errorf
(
"create default group %s: %w"
,
name
,
err
)
}
return
nil
}
backend/internal/repository/simple_mode_default_groups_integration_test.go
0 → 100644
View file @
2fe8932c
//go:build integration
package
repository
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
assertGroupExists
:=
func
(
name
string
)
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
exists
,
"expected group %s to exist"
,
name
)
}
assertGroupExists
(
service
.
PlatformAnthropic
+
"-default"
)
assertGroupExists
(
service
.
PlatformOpenAI
+
"-default"
)
assertGroupExists
(
service
.
PlatformGemini
+
"-default"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-1"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-2"
)
}
func
TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
// Create and then soft-delete an anthropic default group.
g
,
err
:=
client
.
Group
.
Create
()
.
SetName
(
service
.
PlatformAnthropic
+
"-default"
)
.
SetPlatform
(
service
.
PlatformAnthropic
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
seedCtx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
Group
.
Delete
()
.
Where
(
group
.
IDEQ
(
g
.
ID
))
.
Exec
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
// New active one should exist.
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
service
.
PlatformAnthropic
+
"-default"
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
}
func
TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-1-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-2-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
service
.
PlatformAntigravity
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
GreaterOrEqual
(
t
,
count
,
2
)
}
backend/internal/repository/totp_cache.go
0 → 100644
View file @
2fe8932c
package
repository
import
(
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const
(
totpSetupKeyPrefix
=
"totp:setup:"
totpLoginKeyPrefix
=
"totp:login:"
totpAttemptsKeyPrefix
=
"totp:attempts:"
totpAttemptsTTL
=
15
*
time
.
Minute
)
// TotpCache implements service.TotpCache using Redis
type
TotpCache
struct
{
rdb
*
redis
.
Client
}
// NewTotpCache creates a new TOTP cache
func
NewTotpCache
(
rdb
*
redis
.
Client
)
service
.
TotpCache
{
return
&
TotpCache
{
rdb
:
rdb
}
}
// GetSetupSession retrieves a TOTP setup session
func
(
c
*
TotpCache
)
GetSetupSession
(
ctx
context
.
Context
,
userID
int64
)
(
*
service
.
TotpSetupSession
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
data
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Bytes
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
return
nil
,
fmt
.
Errorf
(
"get setup session: %w"
,
err
)
}
var
session
service
.
TotpSetupSession
if
err
:=
json
.
Unmarshal
(
data
,
&
session
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal setup session: %w"
,
err
)
}
return
&
session
,
nil
}
// SetSetupSession stores a TOTP setup session
func
(
c
*
TotpCache
)
SetSetupSession
(
ctx
context
.
Context
,
userID
int64
,
session
*
service
.
TotpSetupSession
,
ttl
time
.
Duration
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
data
,
err
:=
json
.
Marshal
(
session
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal setup session: %w"
,
err
)
}
if
err
:=
c
.
rdb
.
Set
(
ctx
,
key
,
data
,
ttl
)
.
Err
();
err
!=
nil
{
return
fmt
.
Errorf
(
"set setup session: %w"
,
err
)
}
return
nil
}
// DeleteSetupSession deletes a TOTP setup session
func
(
c
*
TotpCache
)
DeleteSetupSession
(
ctx
context
.
Context
,
userID
int64
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpSetupKeyPrefix
,
userID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// GetLoginSession retrieves a TOTP login session
func
(
c
*
TotpCache
)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
service
.
TotpLoginSession
,
error
)
{
key
:=
totpLoginKeyPrefix
+
tempToken
data
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Bytes
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
return
nil
,
fmt
.
Errorf
(
"get login session: %w"
,
err
)
}
var
session
service
.
TotpLoginSession
if
err
:=
json
.
Unmarshal
(
data
,
&
session
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal login session: %w"
,
err
)
}
return
&
session
,
nil
}
// SetLoginSession stores a TOTP login session
func
(
c
*
TotpCache
)
SetLoginSession
(
ctx
context
.
Context
,
tempToken
string
,
session
*
service
.
TotpLoginSession
,
ttl
time
.
Duration
)
error
{
key
:=
totpLoginKeyPrefix
+
tempToken
data
,
err
:=
json
.
Marshal
(
session
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal login session: %w"
,
err
)
}
if
err
:=
c
.
rdb
.
Set
(
ctx
,
key
,
data
,
ttl
)
.
Err
();
err
!=
nil
{
return
fmt
.
Errorf
(
"set login session: %w"
,
err
)
}
return
nil
}
// DeleteLoginSession deletes a TOTP login session
func
(
c
*
TotpCache
)
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
{
key
:=
totpLoginKeyPrefix
+
tempToken
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// IncrementVerifyAttempts increments the verify attempt counter
func
(
c
*
TotpCache
)
IncrementVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpAttemptsKeyPrefix
,
userID
)
// Use pipeline for atomic increment and set TTL
pipe
:=
c
.
rdb
.
Pipeline
()
incrCmd
:=
pipe
.
Incr
(
ctx
,
key
)
pipe
.
Expire
(
ctx
,
key
,
totpAttemptsTTL
)
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"increment verify attempts: %w"
,
err
)
}
count
,
err
:=
incrCmd
.
Result
()
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"get increment result: %w"
,
err
)
}
return
int
(
count
),
nil
}
// GetVerifyAttempts gets the current verify attempt count
func
(
c
*
TotpCache
)
GetVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpAttemptsKeyPrefix
,
userID
)
count
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
0
,
nil
}
return
0
,
fmt
.
Errorf
(
"get verify attempts: %w"
,
err
)
}
return
count
,
nil
}
// ClearVerifyAttempts clears the verify attempt counter
func
(
c
*
TotpCache
)
ClearVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
totpAttemptsKeyPrefix
,
userID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
backend/internal/repository/usage_cleanup_repo.go
0 → 100644
View file @
2fe8932c
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbusagecleanuptask
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type
usageCleanupRepository
struct
{
client
*
dbent
.
Client
sql
sqlExecutor
}
func
NewUsageCleanupRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
)
service
.
UsageCleanupRepository
{
return
newUsageCleanupRepositoryWithSQL
(
client
,
sqlDB
)
}
func
newUsageCleanupRepositoryWithSQL
(
client
*
dbent
.
Client
,
sqlq
sqlExecutor
)
*
usageCleanupRepository
{
return
&
usageCleanupRepository
{
client
:
client
,
sql
:
sqlq
}
}
func
(
r
*
usageCleanupRepository
)
CreateTask
(
ctx
context
.
Context
,
task
*
service
.
UsageCleanupTask
)
error
{
if
task
==
nil
{
return
nil
}
if
r
.
client
!=
nil
{
return
r
.
createTaskWithEnt
(
ctx
,
task
)
}
return
r
.
createTaskWithSQL
(
ctx
,
task
)
}
func
(
r
*
usageCleanupRepository
)
ListTasks
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageCleanupTask
,
*
pagination
.
PaginationResult
,
error
)
{
if
r
.
client
!=
nil
{
return
r
.
listTasksWithEnt
(
ctx
,
params
)
}
var
total
int64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT COUNT(*) FROM usage_cleanup_tasks"
,
nil
,
&
total
);
err
!=
nil
{
return
nil
,
nil
,
err
}
if
total
==
0
{
return
[]
service
.
UsageCleanupTask
{},
paginationResultFromTotal
(
0
,
params
),
nil
}
query
:=
`
SELECT id, status, filters, created_by, deleted_rows, error_message,
canceled_by, canceled_at,
started_at, finished_at, created_at, updated_at
FROM usage_cleanup_tasks
ORDER BY created_at DESC, id DESC
LIMIT $1 OFFSET $2
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
params
.
Limit
(),
params
.
Offset
())
if
err
!=
nil
{
return
nil
,
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
tasks
:=
make
([]
service
.
UsageCleanupTask
,
0
)
for
rows
.
Next
()
{
var
task
service
.
UsageCleanupTask
var
filtersJSON
[]
byte
var
errMsg
sql
.
NullString
var
canceledBy
sql
.
NullInt64
var
canceledAt
sql
.
NullTime
var
startedAt
sql
.
NullTime
var
finishedAt
sql
.
NullTime
if
err
:=
rows
.
Scan
(
&
task
.
ID
,
&
task
.
Status
,
&
filtersJSON
,
&
task
.
CreatedBy
,
&
task
.
DeletedRows
,
&
errMsg
,
&
canceledBy
,
&
canceledAt
,
&
startedAt
,
&
finishedAt
,
&
task
.
CreatedAt
,
&
task
.
UpdatedAt
,
);
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
json
.
Unmarshal
(
filtersJSON
,
&
task
.
Filters
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"parse cleanup filters: %w"
,
err
)
}
if
errMsg
.
Valid
{
task
.
ErrorMsg
=
&
errMsg
.
String
}
if
canceledBy
.
Valid
{
v
:=
canceledBy
.
Int64
task
.
CanceledBy
=
&
v
}
if
canceledAt
.
Valid
{
task
.
CanceledAt
=
&
canceledAt
.
Time
}
if
startedAt
.
Valid
{
task
.
StartedAt
=
&
startedAt
.
Time
}
if
finishedAt
.
Valid
{
task
.
FinishedAt
=
&
finishedAt
.
Time
}
tasks
=
append
(
tasks
,
task
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
nil
,
err
}
return
tasks
,
paginationResultFromTotal
(
total
,
params
),
nil
}
func
(
r
*
usageCleanupRepository
)
ClaimNextPendingTask
(
ctx
context
.
Context
,
staleRunningAfterSeconds
int64
)
(
*
service
.
UsageCleanupTask
,
error
)
{
if
staleRunningAfterSeconds
<=
0
{
staleRunningAfterSeconds
=
1800
}
query
:=
`
WITH next AS (
SELECT id
FROM usage_cleanup_tasks
WHERE status = $1
OR (
status = $2
AND started_at IS NOT NULL
AND started_at < NOW() - ($3 * interval '1 second')
)
ORDER BY created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE usage_cleanup_tasks AS tasks
SET status = $4,
started_at = NOW(),
finished_at = NULL,
error_message = NULL,
updated_at = NOW()
FROM next
WHERE tasks.id = next.id
RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message,
tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at
`
var
task
service
.
UsageCleanupTask
var
filtersJSON
[]
byte
var
errMsg
sql
.
NullString
var
startedAt
sql
.
NullTime
var
finishedAt
sql
.
NullTime
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
staleRunningAfterSeconds
,
service
.
UsageCleanupStatusRunning
,
},
&
task
.
ID
,
&
task
.
Status
,
&
filtersJSON
,
&
task
.
CreatedBy
,
&
task
.
DeletedRows
,
&
errMsg
,
&
startedAt
,
&
finishedAt
,
&
task
.
CreatedAt
,
&
task
.
UpdatedAt
,
);
err
!=
nil
{
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
nil
,
nil
}
return
nil
,
err
}
if
err
:=
json
.
Unmarshal
(
filtersJSON
,
&
task
.
Filters
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse cleanup filters: %w"
,
err
)
}
if
errMsg
.
Valid
{
task
.
ErrorMsg
=
&
errMsg
.
String
}
if
startedAt
.
Valid
{
task
.
StartedAt
=
&
startedAt
.
Time
}
if
finishedAt
.
Valid
{
task
.
FinishedAt
=
&
finishedAt
.
Time
}
return
&
task
,
nil
}
func
(
r
*
usageCleanupRepository
)
GetTaskStatus
(
ctx
context
.
Context
,
taskID
int64
)
(
string
,
error
)
{
if
r
.
client
!=
nil
{
return
r
.
getTaskStatusWithEnt
(
ctx
,
taskID
)
}
var
status
string
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT status FROM usage_cleanup_tasks WHERE id = $1"
,
[]
any
{
taskID
},
&
status
);
err
!=
nil
{
return
""
,
err
}
return
status
,
nil
}
func
(
r
*
usageCleanupRepository
)
UpdateTaskProgress
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
{
if
r
.
client
!=
nil
{
return
r
.
updateTaskProgressWithEnt
(
ctx
,
taskID
,
deletedRows
)
}
query
:=
`
UPDATE usage_cleanup_tasks
SET deleted_rows = $1,
updated_at = NOW()
WHERE id = $2
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
deletedRows
,
taskID
)
return
err
}
func
(
r
*
usageCleanupRepository
)
CancelTask
(
ctx
context
.
Context
,
taskID
int64
,
canceledBy
int64
)
(
bool
,
error
)
{
if
r
.
client
!=
nil
{
return
r
.
cancelTaskWithEnt
(
ctx
,
taskID
,
canceledBy
)
}
query
:=
`
UPDATE usage_cleanup_tasks
SET status = $1,
canceled_by = $3,
canceled_at = NOW(),
finished_at = NOW(),
error_message = NULL,
updated_at = NOW()
WHERE id = $2
AND status IN ($4, $5)
RETURNING id
`
var
id
int64
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
service
.
UsageCleanupStatusCanceled
,
taskID
,
canceledBy
,
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
},
&
id
)
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
false
,
nil
}
if
err
!=
nil
{
return
false
,
err
}
return
true
,
nil
}
func
(
r
*
usageCleanupRepository
)
MarkTaskSucceeded
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
{
if
r
.
client
!=
nil
{
return
r
.
markTaskSucceededWithEnt
(
ctx
,
taskID
,
deletedRows
)
}
query
:=
`
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $3
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
service
.
UsageCleanupStatusSucceeded
,
deletedRows
,
taskID
)
return
err
}
func
(
r
*
usageCleanupRepository
)
MarkTaskFailed
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
,
errorMsg
string
)
error
{
if
r
.
client
!=
nil
{
return
r
.
markTaskFailedWithEnt
(
ctx
,
taskID
,
deletedRows
,
errorMsg
)
}
query
:=
`
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
error_message = $3,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $4
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
service
.
UsageCleanupStatusFailed
,
deletedRows
,
errorMsg
,
taskID
)
return
err
}
func
(
r
*
usageCleanupRepository
)
DeleteUsageLogsBatch
(
ctx
context
.
Context
,
filters
service
.
UsageCleanupFilters
,
limit
int
)
(
int64
,
error
)
{
if
filters
.
StartTime
.
IsZero
()
||
filters
.
EndTime
.
IsZero
()
{
return
0
,
fmt
.
Errorf
(
"cleanup filters missing time range"
)
}
whereClause
,
args
:=
buildUsageCleanupWhere
(
filters
)
if
whereClause
==
""
{
return
0
,
fmt
.
Errorf
(
"cleanup filters missing time range"
)
}
args
=
append
(
args
,
limit
)
query
:=
fmt
.
Sprintf
(
`
WITH target AS (
SELECT id
FROM usage_logs
WHERE %s
ORDER BY created_at ASC, id ASC
LIMIT $%d
)
DELETE FROM usage_logs
WHERE id IN (SELECT id FROM target)
RETURNING id
`
,
whereClause
,
len
(
args
))
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
if
err
!=
nil
{
return
0
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
deleted
int64
for
rows
.
Next
()
{
deleted
++
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
0
,
err
}
return
deleted
,
nil
}
func
buildUsageCleanupWhere
(
filters
service
.
UsageCleanupFilters
)
(
string
,
[]
any
)
{
conditions
:=
make
([]
string
,
0
,
8
)
args
:=
make
([]
any
,
0
,
8
)
idx
:=
1
if
!
filters
.
StartTime
.
IsZero
()
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at >= $%d"
,
idx
))
args
=
append
(
args
,
filters
.
StartTime
)
idx
++
}
if
!
filters
.
EndTime
.
IsZero
()
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at <= $%d"
,
idx
))
args
=
append
(
args
,
filters
.
EndTime
)
idx
++
}
if
filters
.
UserID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"user_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
UserID
)
idx
++
}
if
filters
.
APIKeyID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"api_key_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
APIKeyID
)
idx
++
}
if
filters
.
AccountID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"account_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
AccountID
)
idx
++
}
if
filters
.
GroupID
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"group_id = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
GroupID
)
idx
++
}
if
filters
.
Model
!=
nil
{
model
:=
strings
.
TrimSpace
(
*
filters
.
Model
)
if
model
!=
""
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"model = $%d"
,
idx
))
args
=
append
(
args
,
model
)
idx
++
}
}
if
filters
.
Stream
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"stream = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
Stream
)
idx
++
}
if
filters
.
BillingType
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"billing_type = $%d"
,
idx
))
args
=
append
(
args
,
*
filters
.
BillingType
)
}
return
strings
.
Join
(
conditions
,
" AND "
),
args
}
func
(
r
*
usageCleanupRepository
)
createTaskWithEnt
(
ctx
context
.
Context
,
task
*
service
.
UsageCleanupTask
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
filtersJSON
,
err
:=
json
.
Marshal
(
task
.
Filters
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal cleanup filters: %w"
,
err
)
}
created
,
err
:=
client
.
UsageCleanupTask
.
Create
()
.
SetStatus
(
task
.
Status
)
.
SetFilters
(
json
.
RawMessage
(
filtersJSON
))
.
SetCreatedBy
(
task
.
CreatedBy
)
.
SetDeletedRows
(
task
.
DeletedRows
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
err
}
task
.
ID
=
created
.
ID
task
.
CreatedAt
=
created
.
CreatedAt
task
.
UpdatedAt
=
created
.
UpdatedAt
return
nil
}
func
(
r
*
usageCleanupRepository
)
createTaskWithSQL
(
ctx
context
.
Context
,
task
*
service
.
UsageCleanupTask
)
error
{
filtersJSON
,
err
:=
json
.
Marshal
(
task
.
Filters
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal cleanup filters: %w"
,
err
)
}
query
:=
`
INSERT INTO usage_cleanup_tasks (
status,
filters,
created_by,
deleted_rows
) VALUES ($1, $2, $3, $4)
RETURNING id, created_at, updated_at
`
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
task
.
Status
,
filtersJSON
,
task
.
CreatedBy
,
task
.
DeletedRows
},
&
task
.
ID
,
&
task
.
CreatedAt
,
&
task
.
UpdatedAt
);
err
!=
nil
{
return
err
}
return
nil
}
func
(
r
*
usageCleanupRepository
)
listTasksWithEnt
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageCleanupTask
,
*
pagination
.
PaginationResult
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
query
:=
client
.
UsageCleanupTask
.
Query
()
total
,
err
:=
query
.
Clone
()
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
if
total
==
0
{
return
[]
service
.
UsageCleanupTask
{},
paginationResultFromTotal
(
0
,
params
),
nil
}
rows
,
err
:=
query
.
Order
(
dbent
.
Desc
(
dbusagecleanuptask
.
FieldCreatedAt
),
dbent
.
Desc
(
dbusagecleanuptask
.
FieldID
))
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
tasks
:=
make
([]
service
.
UsageCleanupTask
,
0
,
len
(
rows
))
for
_
,
row
:=
range
rows
{
task
,
err
:=
usageCleanupTaskFromEnt
(
row
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
tasks
=
append
(
tasks
,
task
)
}
return
tasks
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
usageCleanupRepository
)
getTaskStatusWithEnt
(
ctx
context
.
Context
,
taskID
int64
)
(
string
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
task
,
err
:=
client
.
UsageCleanupTask
.
Query
()
.
Where
(
dbusagecleanuptask
.
IDEQ
(
taskID
))
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
""
,
sql
.
ErrNoRows
}
return
""
,
err
}
return
task
.
Status
,
nil
}
func
(
r
*
usageCleanupRepository
)
updateTaskProgressWithEnt
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
now
:=
time
.
Now
()
_
,
err
:=
client
.
UsageCleanupTask
.
Update
()
.
Where
(
dbusagecleanuptask
.
IDEQ
(
taskID
))
.
SetDeletedRows
(
deletedRows
)
.
SetUpdatedAt
(
now
)
.
Save
(
ctx
)
return
err
}
func
(
r
*
usageCleanupRepository
)
cancelTaskWithEnt
(
ctx
context
.
Context
,
taskID
int64
,
canceledBy
int64
)
(
bool
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
now
:=
time
.
Now
()
affected
,
err
:=
client
.
UsageCleanupTask
.
Update
()
.
Where
(
dbusagecleanuptask
.
IDEQ
(
taskID
),
dbusagecleanuptask
.
StatusIn
(
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
),
)
.
SetStatus
(
service
.
UsageCleanupStatusCanceled
)
.
SetCanceledBy
(
canceledBy
)
.
SetCanceledAt
(
now
)
.
SetFinishedAt
(
now
)
.
ClearErrorMessage
()
.
SetUpdatedAt
(
now
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
false
,
err
}
return
affected
>
0
,
nil
}
func
(
r
*
usageCleanupRepository
)
markTaskSucceededWithEnt
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
now
:=
time
.
Now
()
_
,
err
:=
client
.
UsageCleanupTask
.
Update
()
.
Where
(
dbusagecleanuptask
.
IDEQ
(
taskID
))
.
SetStatus
(
service
.
UsageCleanupStatusSucceeded
)
.
SetDeletedRows
(
deletedRows
)
.
SetFinishedAt
(
now
)
.
SetUpdatedAt
(
now
)
.
Save
(
ctx
)
return
err
}
func
(
r
*
usageCleanupRepository
)
markTaskFailedWithEnt
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
,
errorMsg
string
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
now
:=
time
.
Now
()
_
,
err
:=
client
.
UsageCleanupTask
.
Update
()
.
Where
(
dbusagecleanuptask
.
IDEQ
(
taskID
))
.
SetStatus
(
service
.
UsageCleanupStatusFailed
)
.
SetDeletedRows
(
deletedRows
)
.
SetErrorMessage
(
errorMsg
)
.
SetFinishedAt
(
now
)
.
SetUpdatedAt
(
now
)
.
Save
(
ctx
)
return
err
}
func
usageCleanupTaskFromEnt
(
row
*
dbent
.
UsageCleanupTask
)
(
service
.
UsageCleanupTask
,
error
)
{
task
:=
service
.
UsageCleanupTask
{
ID
:
row
.
ID
,
Status
:
row
.
Status
,
CreatedBy
:
row
.
CreatedBy
,
DeletedRows
:
row
.
DeletedRows
,
CreatedAt
:
row
.
CreatedAt
,
UpdatedAt
:
row
.
UpdatedAt
,
}
if
len
(
row
.
Filters
)
>
0
{
if
err
:=
json
.
Unmarshal
(
row
.
Filters
,
&
task
.
Filters
);
err
!=
nil
{
return
service
.
UsageCleanupTask
{},
fmt
.
Errorf
(
"parse cleanup filters: %w"
,
err
)
}
}
if
row
.
ErrorMessage
!=
nil
{
task
.
ErrorMsg
=
row
.
ErrorMessage
}
if
row
.
CanceledBy
!=
nil
{
task
.
CanceledBy
=
row
.
CanceledBy
}
if
row
.
CanceledAt
!=
nil
{
task
.
CanceledAt
=
row
.
CanceledAt
}
if
row
.
StartedAt
!=
nil
{
task
.
StartedAt
=
row
.
StartedAt
}
if
row
.
FinishedAt
!=
nil
{
task
.
FinishedAt
=
row
.
FinishedAt
}
return
task
,
nil
}
backend/internal/repository/usage_cleanup_repo_ent_test.go
0 → 100644
View file @
2fe8932c
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
dbusagecleanuptask
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql
"entgo.io/ent/dialect/sql"
_
"modernc.org/sqlite"
)
func
newUsageCleanupEntRepo
(
t
*
testing
.
T
)
(
*
usageCleanupRepository
,
*
dbent
.
Client
)
{
t
.
Helper
()
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:usage_cleanup?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
repo
:=
&
usageCleanupRepository
{
client
:
client
,
sql
:
db
}
return
repo
,
client
}
func
TestUsageCleanupRepositoryEntCreateAndList
(
t
*
testing
.
T
)
{
repo
,
_
:=
newUsageCleanupEntRepo
(
t
)
start
:=
time
.
Date
(
2024
,
1
,
2
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
task
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusPending
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
},
CreatedBy
:
9
,
}
require
.
NoError
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task
))
require
.
NotZero
(
t
,
task
.
ID
)
task2
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusRunning
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
start
.
Add
(
-
24
*
time
.
Hour
),
EndTime
:
end
.
Add
(
-
24
*
time
.
Hour
)},
CreatedBy
:
10
,
}
require
.
NoError
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task2
))
tasks
,
result
,
err
:=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
tasks
,
2
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Total
)
require
.
Greater
(
t
,
tasks
[
0
]
.
ID
,
tasks
[
1
]
.
ID
)
require
.
Equal
(
t
,
start
,
tasks
[
1
]
.
Filters
.
StartTime
)
require
.
Equal
(
t
,
end
,
tasks
[
1
]
.
Filters
.
EndTime
)
}
func
TestUsageCleanupRepositoryEntListEmpty
(
t
*
testing
.
T
)
{
repo
,
_
:=
newUsageCleanupEntRepo
(
t
)
tasks
,
result
,
err
:=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
tasks
)
require
.
Equal
(
t
,
int64
(
0
),
result
.
Total
)
}
func
TestUsageCleanupRepositoryEntGetStatusAndProgress
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUsageCleanupEntRepo
(
t
)
task
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusPending
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
)},
CreatedBy
:
3
,
}
require
.
NoError
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task
))
status
,
err
:=
repo
.
GetTaskStatus
(
context
.
Background
(),
task
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusPending
,
status
)
_
,
err
=
repo
.
GetTaskStatus
(
context
.
Background
(),
task
.
ID
+
99
)
require
.
ErrorIs
(
t
,
err
,
sql
.
ErrNoRows
)
require
.
NoError
(
t
,
repo
.
UpdateTaskProgress
(
context
.
Background
(),
task
.
ID
,
42
))
loaded
,
err
:=
client
.
UsageCleanupTask
.
Get
(
context
.
Background
(),
task
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
42
),
loaded
.
DeletedRows
)
}
func
TestUsageCleanupRepositoryEntCancelAndFinish
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUsageCleanupEntRepo
(
t
)
task
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusPending
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
)},
CreatedBy
:
5
,
}
require
.
NoError
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task
))
ok
,
err
:=
repo
.
CancelTask
(
context
.
Background
(),
task
.
ID
,
7
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
ok
)
loaded
,
err
:=
client
.
UsageCleanupTask
.
Get
(
context
.
Background
(),
task
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusCanceled
,
loaded
.
Status
)
require
.
NotNil
(
t
,
loaded
.
CanceledBy
)
require
.
NotNil
(
t
,
loaded
.
CanceledAt
)
require
.
NotNil
(
t
,
loaded
.
FinishedAt
)
loaded
.
Status
=
service
.
UsageCleanupStatusSucceeded
_
,
err
=
client
.
UsageCleanupTask
.
Update
()
.
Where
(
dbusagecleanuptask
.
IDEQ
(
task
.
ID
))
.
SetStatus
(
loaded
.
Status
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
ok
,
err
=
repo
.
CancelTask
(
context
.
Background
(),
task
.
ID
,
7
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
ok
)
}
func
TestUsageCleanupRepositoryEntCancelError
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUsageCleanupEntRepo
(
t
)
task
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusPending
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
)},
CreatedBy
:
5
,
}
require
.
NoError
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task
))
require
.
NoError
(
t
,
client
.
Close
())
_
,
err
:=
repo
.
CancelTask
(
context
.
Background
(),
task
.
ID
,
7
)
require
.
Error
(
t
,
err
)
}
func
TestUsageCleanupRepositoryEntMarkResults
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUsageCleanupEntRepo
(
t
)
task
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusRunning
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
)},
CreatedBy
:
12
,
}
require
.
NoError
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task
))
require
.
NoError
(
t
,
repo
.
MarkTaskSucceeded
(
context
.
Background
(),
task
.
ID
,
6
))
loaded
,
err
:=
client
.
UsageCleanupTask
.
Get
(
context
.
Background
(),
task
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusSucceeded
,
loaded
.
Status
)
require
.
Equal
(
t
,
int64
(
6
),
loaded
.
DeletedRows
)
require
.
NotNil
(
t
,
loaded
.
FinishedAt
)
task2
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusRunning
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
)},
CreatedBy
:
12
,
}
require
.
NoError
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task2
))
require
.
NoError
(
t
,
repo
.
MarkTaskFailed
(
context
.
Background
(),
task2
.
ID
,
4
,
"boom"
))
loaded2
,
err
:=
client
.
UsageCleanupTask
.
Get
(
context
.
Background
(),
task2
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusFailed
,
loaded2
.
Status
)
require
.
Equal
(
t
,
"boom"
,
*
loaded2
.
ErrorMessage
)
}
func
TestUsageCleanupRepositoryEntInvalidStatus
(
t
*
testing
.
T
)
{
repo
,
_
:=
newUsageCleanupEntRepo
(
t
)
task
:=
&
service
.
UsageCleanupTask
{
Status
:
"invalid"
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
)},
CreatedBy
:
1
,
}
require
.
Error
(
t
,
repo
.
CreateTask
(
context
.
Background
(),
task
))
}
func
TestUsageCleanupRepositoryEntListInvalidFilters
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUsageCleanupEntRepo
(
t
)
now
:=
time
.
Now
()
.
UTC
()
driver
,
ok
:=
client
.
Driver
()
.
(
*
entsql
.
Driver
)
require
.
True
(
t
,
ok
)
_
,
err
:=
driver
.
DB
()
.
ExecContext
(
context
.
Background
(),
`INSERT INTO usage_cleanup_tasks (status, filters, created_by, deleted_rows, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)`
,
service
.
UsageCleanupStatusPending
,
[]
byte
(
"invalid-json"
),
int64
(
1
),
int64
(
0
),
now
,
now
,
)
require
.
NoError
(
t
,
err
)
_
,
_
,
err
=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
require
.
Error
(
t
,
err
)
}
func
TestUsageCleanupTaskFromEntFull
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2024
,
1
,
2
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
errMsg
:=
"failed"
canceledBy
:=
int64
(
2
)
canceledAt
:=
start
.
Add
(
time
.
Minute
)
startedAt
:=
start
.
Add
(
2
*
time
.
Minute
)
finishedAt
:=
start
.
Add
(
3
*
time
.
Minute
)
filters
:=
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
}
filtersJSON
,
err
:=
json
.
Marshal
(
filters
)
require
.
NoError
(
t
,
err
)
task
,
err
:=
usageCleanupTaskFromEnt
(
&
dbent
.
UsageCleanupTask
{
ID
:
10
,
Status
:
service
.
UsageCleanupStatusFailed
,
Filters
:
filtersJSON
,
CreatedBy
:
11
,
DeletedRows
:
7
,
ErrorMessage
:
&
errMsg
,
CanceledBy
:
&
canceledBy
,
CanceledAt
:
&
canceledAt
,
StartedAt
:
&
startedAt
,
FinishedAt
:
&
finishedAt
,
CreatedAt
:
start
,
UpdatedAt
:
end
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
),
task
.
ID
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusFailed
,
task
.
Status
)
require
.
NotNil
(
t
,
task
.
ErrorMsg
)
require
.
NotNil
(
t
,
task
.
CanceledBy
)
require
.
NotNil
(
t
,
task
.
CanceledAt
)
require
.
NotNil
(
t
,
task
.
StartedAt
)
require
.
NotNil
(
t
,
task
.
FinishedAt
)
}
func
TestUsageCleanupTaskFromEntInvalidFilters
(
t
*
testing
.
T
)
{
task
,
err
:=
usageCleanupTaskFromEnt
(
&
dbent
.
UsageCleanupTask
{
Filters
:
json
.
RawMessage
(
"invalid-json"
),
})
require
.
Error
(
t
,
err
)
require
.
Empty
(
t
,
task
)
}
backend/internal/repository/usage_cleanup_repo_test.go
0 → 100644
View file @
2fe8932c
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
newSQLMock
(
t
*
testing
.
T
)
(
*
sql
.
DB
,
sqlmock
.
Sqlmock
)
{
t
.
Helper
()
db
,
mock
,
err
:=
sqlmock
.
New
(
sqlmock
.
QueryMatcherOption
(
sqlmock
.
QueryMatcherRegexp
))
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
return
db
,
mock
}
func
TestNewUsageCleanupRepository
(
t
*
testing
.
T
)
{
db
,
_
:=
newSQLMock
(
t
)
repo
:=
NewUsageCleanupRepository
(
nil
,
db
)
require
.
NotNil
(
t
,
repo
)
}
func
TestUsageCleanupRepositoryCreateTask
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
task
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusPending
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
},
CreatedBy
:
12
,
}
now
:=
time
.
Date
(
2024
,
1
,
2
,
0
,
0
,
0
,
0
,
time
.
UTC
)
mock
.
ExpectQuery
(
"INSERT INTO usage_cleanup_tasks"
)
.
WithArgs
(
task
.
Status
,
sqlmock
.
AnyArg
(),
task
.
CreatedBy
,
task
.
DeletedRows
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
,
"created_at"
,
"updated_at"
})
.
AddRow
(
int64
(
1
),
now
,
now
))
err
:=
repo
.
CreateTask
(
context
.
Background
(),
task
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
task
.
ID
)
require
.
Equal
(
t
,
now
,
task
.
CreatedAt
)
require
.
Equal
(
t
,
now
,
task
.
UpdatedAt
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryCreateTaskNil
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
err
:=
repo
.
CreateTask
(
context
.
Background
(),
nil
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryCreateTaskQueryError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
task
:=
&
service
.
UsageCleanupTask
{
Status
:
service
.
UsageCleanupStatusPending
,
Filters
:
service
.
UsageCleanupFilters
{
StartTime
:
time
.
Now
(),
EndTime
:
time
.
Now
()
.
Add
(
time
.
Hour
)},
CreatedBy
:
1
,
}
mock
.
ExpectQuery
(
"INSERT INTO usage_cleanup_tasks"
)
.
WithArgs
(
task
.
Status
,
sqlmock
.
AnyArg
(),
task
.
CreatedBy
,
task
.
DeletedRows
)
.
WillReturnError
(
sql
.
ErrConnDone
)
err
:=
repo
.
CreateTask
(
context
.
Background
(),
task
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryListTasksEmpty
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"SELECT COUNT
\\
(
\\
*
\\
) FROM usage_cleanup_tasks"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
0
)))
tasks
,
result
,
err
:=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
tasks
)
require
.
Equal
(
t
,
int64
(
0
),
result
.
Total
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryListTasks
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
2
*
time
.
Hour
)
filters
:=
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
}
filtersJSON
,
err
:=
json
.
Marshal
(
filters
)
require
.
NoError
(
t
,
err
)
createdAt
:=
time
.
Date
(
2024
,
1
,
2
,
12
,
0
,
0
,
0
,
time
.
UTC
)
updatedAt
:=
createdAt
.
Add
(
time
.
Minute
)
rows
:=
sqlmock
.
NewRows
([]
string
{
"id"
,
"status"
,
"filters"
,
"created_by"
,
"deleted_rows"
,
"error_message"
,
"canceled_by"
,
"canceled_at"
,
"started_at"
,
"finished_at"
,
"created_at"
,
"updated_at"
,
})
.
AddRow
(
int64
(
1
),
service
.
UsageCleanupStatusSucceeded
,
filtersJSON
,
int64
(
2
),
int64
(
9
),
"error"
,
nil
,
nil
,
start
,
end
,
createdAt
,
updatedAt
,
)
mock
.
ExpectQuery
(
"SELECT COUNT
\\
(
\\
*
\\
) FROM usage_cleanup_tasks"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
1
)))
mock
.
ExpectQuery
(
"SELECT id, status, filters, created_by, deleted_rows, error_message"
)
.
WithArgs
(
20
,
0
)
.
WillReturnRows
(
rows
)
tasks
,
result
,
err
:=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
tasks
,
1
)
require
.
Equal
(
t
,
int64
(
1
),
tasks
[
0
]
.
ID
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusSucceeded
,
tasks
[
0
]
.
Status
)
require
.
Equal
(
t
,
int64
(
2
),
tasks
[
0
]
.
CreatedBy
)
require
.
Equal
(
t
,
int64
(
9
),
tasks
[
0
]
.
DeletedRows
)
require
.
NotNil
(
t
,
tasks
[
0
]
.
ErrorMsg
)
require
.
Equal
(
t
,
"error"
,
*
tasks
[
0
]
.
ErrorMsg
)
require
.
NotNil
(
t
,
tasks
[
0
]
.
StartedAt
)
require
.
NotNil
(
t
,
tasks
[
0
]
.
FinishedAt
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Total
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryListTasksQueryError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"SELECT COUNT
\\
(
\\
*
\\
) FROM usage_cleanup_tasks"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
2
)))
mock
.
ExpectQuery
(
"SELECT id, status, filters, created_by, deleted_rows, error_message"
)
.
WithArgs
(
20
,
0
)
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
_
,
err
:=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryListTasksInvalidFilters
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
rows
:=
sqlmock
.
NewRows
([]
string
{
"id"
,
"status"
,
"filters"
,
"created_by"
,
"deleted_rows"
,
"error_message"
,
"canceled_by"
,
"canceled_at"
,
"started_at"
,
"finished_at"
,
"created_at"
,
"updated_at"
,
})
.
AddRow
(
int64
(
1
),
service
.
UsageCleanupStatusSucceeded
,
[]
byte
(
"not-json"
),
int64
(
2
),
int64
(
9
),
nil
,
nil
,
nil
,
nil
,
nil
,
time
.
Now
()
.
UTC
(),
time
.
Now
()
.
UTC
(),
)
mock
.
ExpectQuery
(
"SELECT COUNT
\\
(
\\
*
\\
) FROM usage_cleanup_tasks"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
1
)))
mock
.
ExpectQuery
(
"SELECT id, status, filters, created_by, deleted_rows, error_message"
)
.
WithArgs
(
20
,
0
)
.
WillReturnRows
(
rows
)
_
,
_
,
err
:=
repo
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryClaimNextPendingTaskNone
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
int64
(
1800
),
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
,
"status"
,
"filters"
,
"created_by"
,
"deleted_rows"
,
"error_message"
,
"started_at"
,
"finished_at"
,
"created_at"
,
"updated_at"
,
}))
task
,
err
:=
repo
.
ClaimNextPendingTask
(
context
.
Background
(),
1800
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
task
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryClaimNextPendingTask
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
filters
:=
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
}
filtersJSON
,
err
:=
json
.
Marshal
(
filters
)
require
.
NoError
(
t
,
err
)
rows
:=
sqlmock
.
NewRows
([]
string
{
"id"
,
"status"
,
"filters"
,
"created_by"
,
"deleted_rows"
,
"error_message"
,
"started_at"
,
"finished_at"
,
"created_at"
,
"updated_at"
,
})
.
AddRow
(
int64
(
4
),
service
.
UsageCleanupStatusRunning
,
filtersJSON
,
int64
(
7
),
int64
(
0
),
nil
,
start
,
nil
,
start
,
start
,
)
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
int64
(
1800
),
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
rows
)
task
,
err
:=
repo
.
ClaimNextPendingTask
(
context
.
Background
(),
1800
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
task
)
require
.
Equal
(
t
,
int64
(
4
),
task
.
ID
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusRunning
,
task
.
Status
)
require
.
Equal
(
t
,
int64
(
7
),
task
.
CreatedBy
)
require
.
NotNil
(
t
,
task
.
StartedAt
)
require
.
Nil
(
t
,
task
.
ErrorMsg
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryClaimNextPendingTaskError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
int64
(
1800
),
service
.
UsageCleanupStatusRunning
)
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
err
:=
repo
.
ClaimNextPendingTask
(
context
.
Background
(),
1800
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
rows
:=
sqlmock
.
NewRows
([]
string
{
"id"
,
"status"
,
"filters"
,
"created_by"
,
"deleted_rows"
,
"error_message"
,
"started_at"
,
"finished_at"
,
"created_at"
,
"updated_at"
,
})
.
AddRow
(
int64
(
4
),
service
.
UsageCleanupStatusRunning
,
[]
byte
(
"invalid"
),
int64
(
7
),
int64
(
0
),
nil
,
nil
,
nil
,
time
.
Now
()
.
UTC
(),
time
.
Now
()
.
UTC
(),
)
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
,
int64
(
1800
),
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
rows
)
_
,
err
:=
repo
.
ClaimNextPendingTask
(
context
.
Background
(),
1800
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryMarkTaskSucceeded
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectExec
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusSucceeded
,
int64
(
12
),
int64
(
9
))
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
err
:=
repo
.
MarkTaskSucceeded
(
context
.
Background
(),
9
,
12
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryMarkTaskFailed
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectExec
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusFailed
,
int64
(
4
),
"boom"
,
int64
(
2
))
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
err
:=
repo
.
MarkTaskFailed
(
context
.
Background
(),
2
,
4
,
"boom"
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryGetTaskStatus
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"SELECT status FROM usage_cleanup_tasks"
)
.
WithArgs
(
int64
(
9
))
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"status"
})
.
AddRow
(
service
.
UsageCleanupStatusPending
))
status
,
err
:=
repo
.
GetTaskStatus
(
context
.
Background
(),
9
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
service
.
UsageCleanupStatusPending
,
status
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryGetTaskStatusQueryError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"SELECT status FROM usage_cleanup_tasks"
)
.
WithArgs
(
int64
(
9
))
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
err
:=
repo
.
GetTaskStatus
(
context
.
Background
(),
9
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryUpdateTaskProgress
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectExec
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
int64
(
123
),
int64
(
8
))
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
err
:=
repo
.
UpdateTaskProgress
(
context
.
Background
(),
8
,
123
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryCancelTask
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusCanceled
,
int64
(
6
),
int64
(
9
),
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
})
.
AddRow
(
int64
(
6
)))
ok
,
err
:=
repo
.
CancelTask
(
context
.
Background
(),
6
,
9
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
ok
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryCancelTaskNoRows
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
mock
.
ExpectQuery
(
"UPDATE usage_cleanup_tasks"
)
.
WithArgs
(
service
.
UsageCleanupStatusCanceled
,
int64
(
6
),
int64
(
9
),
service
.
UsageCleanupStatusPending
,
service
.
UsageCleanupStatusRunning
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
}))
ok
,
err
:=
repo
.
CancelTask
(
context
.
Background
(),
6
,
9
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
ok
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange
(
t
*
testing
.
T
)
{
db
,
_
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
_
,
err
:=
repo
.
DeleteUsageLogsBatch
(
context
.
Background
(),
service
.
UsageCleanupFilters
{},
10
)
require
.
Error
(
t
,
err
)
}
func
TestUsageCleanupRepositoryDeleteUsageLogsBatch
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
userID
:=
int64
(
3
)
model
:=
" gpt-4 "
filters
:=
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
,
UserID
:
&
userID
,
Model
:
&
model
,
}
mock
.
ExpectQuery
(
"DELETE FROM usage_logs"
)
.
WithArgs
(
start
,
end
,
userID
,
"gpt-4"
,
2
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
})
.
AddRow
(
int64
(
1
))
.
AddRow
(
int64
(
2
)))
deleted
,
err
:=
repo
.
DeleteUsageLogsBatch
(
context
.
Background
(),
filters
,
2
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
deleted
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageCleanupRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
filters
:=
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
}
mock
.
ExpectQuery
(
"DELETE FROM usage_logs"
)
.
WithArgs
(
start
,
end
,
5
)
.
WillReturnError
(
sql
.
ErrConnDone
)
_
,
err
:=
repo
.
DeleteUsageLogsBatch
(
context
.
Background
(),
filters
,
5
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestBuildUsageCleanupWhere
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
userID
:=
int64
(
1
)
apiKeyID
:=
int64
(
2
)
accountID
:=
int64
(
3
)
groupID
:=
int64
(
4
)
model
:=
" gpt-4 "
stream
:=
true
billingType
:=
int8
(
2
)
where
,
args
:=
buildUsageCleanupWhere
(
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
,
UserID
:
&
userID
,
APIKeyID
:
&
apiKeyID
,
AccountID
:
&
accountID
,
GroupID
:
&
groupID
,
Model
:
&
model
,
Stream
:
&
stream
,
BillingType
:
&
billingType
,
})
require
.
Equal
(
t
,
"created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9"
,
where
)
require
.
Equal
(
t
,
[]
any
{
start
,
end
,
userID
,
apiKeyID
,
accountID
,
groupID
,
"gpt-4"
,
stream
,
billingType
},
args
)
}
func
TestBuildUsageCleanupWhereModelEmpty
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
model
:=
" "
where
,
args
:=
buildUsageCleanupWhere
(
service
.
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
,
Model
:
&
model
,
})
require
.
Equal
(
t
,
"created_at >= $1 AND created_at <= $2"
,
where
)
require
.
Equal
(
t
,
[]
any
{
start
,
end
},
args
)
}
backend/internal/repository/usage_log_repo.go
View file @
2fe8932c
...
...
@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
(
results
[]
TrendDataPoint
,
err
error
)
{
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
...
...
@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
if
billingType
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND billing_type = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
int16
(
*
billingType
))
}
query
+=
" GROUP BY date ORDER BY date ASC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
(
results
[]
ModelStat
,
err
error
)
{
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
ModelStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
...
...
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
if
billingType
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND billing_type = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
int16
(
*
billingType
))
}
query
+=
" GROUP BY model ORDER BY total_tokens DESC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
nil
)
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
models
=
[]
ModelStat
{}
}
...
...
backend/internal/repository/usage_log_repo_integration_test.go
View file @
2fe8932c
...
...
@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime
:=
base
.
Add
(
48
*
time
.
Hour
)
// Test with user filter
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters user filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with apiKey filter
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with both filters
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters both filters"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime
:=
base
.
Add
(
-
1
*
time
.
Hour
)
endTime
:=
base
.
Add
(
3
*
time
.
Hour
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters hourly"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime
:=
base
.
Add
(
2
*
time
.
Hour
)
// Test with user filter
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
)
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters user filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with apiKey filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with account filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters account filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
}
...
...
backend/internal/repository/user_repo.go
View file @
2fe8932c
...
...
@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbuser
"github.com/Wei-Shaw/sub2api/ent/user"
...
...
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst
.
CreatedAt
=
src
.
CreatedAt
dst
.
UpdatedAt
=
src
.
UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func
(
r
*
userRepository
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
update
:=
client
.
User
.
UpdateOneID
(
userID
)
if
encryptedSecret
==
nil
{
update
=
update
.
ClearTotpSecretEncrypted
()
}
else
{
update
=
update
.
SetTotpSecretEncrypted
(
*
encryptedSecret
)
}
_
,
err
:=
update
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
true
)
.
SetTotpEnabledAt
(
time
.
Now
())
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
false
)
.
ClearTotpEnabledAt
()
.
ClearTotpSecretEncrypted
()
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
backend/internal/repository/user_subscription_repo.go
View file @
2fe8932c
...
...
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return
userSubscriptionEntitiesToService
(
subs
),
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
q
:=
client
.
UserSubscription
.
Query
()
if
userID
!=
nil
{
...
...
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if
groupID
!=
nil
{
q
=
q
.
Where
(
usersubscription
.
GroupIDEQ
(
*
groupID
))
}
if
status
!=
""
{
// Status filtering with real-time expiration check
now
:=
time
.
Now
()
switch
status
{
case
service
.
SubscriptionStatusActive
:
// Active: status is active AND not yet expired
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtGT
(
now
),
)
case
service
.
SubscriptionStatusExpired
:
// Expired: status is expired OR (status is active but already expired)
q
=
q
.
Where
(
usersubscription
.
Or
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusExpired
),
usersubscription
.
And
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtLTE
(
now
),
),
),
)
case
""
:
// No filter
default
:
// Other status (e.g., revoked)
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
status
))
}
...
...
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return
nil
,
nil
,
err
}
// Apply sorting
q
=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
// Determine sort field
var
field
string
switch
sortBy
{
case
"expires_at"
:
field
=
usersubscription
.
FieldExpiresAt
case
"status"
:
field
=
usersubscription
.
FieldStatus
default
:
field
=
usersubscription
.
FieldCreatedAt
}
// Determine sort order (default: desc)
if
sortOrder
==
"asc"
&&
sortBy
!=
""
{
q
=
q
.
Order
(
dbent
.
Asc
(
field
))
}
else
{
q
=
q
.
Order
(
dbent
.
Desc
(
field
))
}
subs
,
err
:=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
.
Order
(
dbent
.
Desc
(
usersubscription
.
FieldCreatedAt
))
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
All
(
ctx
)
...
...
backend/internal/repository/user_subscription_repo_integration_test.go
View file @
2fe8932c
...
...
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group
:=
s
.
mustCreateGroup
(
"g-list"
)
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
...
...
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s
.
mustCreateSubscription
(
user1
.
ID
,
group
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user2
.
ID
,
group
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
user1
.
ID
,
subs
[
0
]
.
UserID
)
...
...
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s
.
mustCreateSubscription
(
user
.
ID
,
g1
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user
.
ID
,
g2
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
g1
.
ID
,
subs
[
0
]
.
GroupID
)
...
...
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
-
24
*
time
.
Hour
))
})
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
service
.
SubscriptionStatusExpired
,
subs
[
0
]
.
Status
)
...
...
backend/internal/repository/wire.go
View file @
2fe8932c
...
...
@@ -57,6 +57,7 @@ var ProviderSet = wire.NewSet(
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewUsageLogRepository
,
NewUsageCleanupRepository
,
NewDashboardAggregationRepository
,
NewSettingRepository
,
NewOpsRepository
,
...
...
@@ -81,6 +82,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache
,
NewSchedulerOutboxRepository
,
NewProxyLatencyCache
,
NewTotpCache
,
// Encryptors
NewAESEncryptor
,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
...
...
backend/internal/server/api_contract_test.go
View file @
2fe8932c
...
...
@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
"concurrency": 5,
...
...
@@ -131,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
}
}`
,
},
{
name
:
"GET /api/v1/groups/available"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count)。
deps
.
groupRepo
.
SetActive
([]
service
.
Group
{
{
ID
:
10
,
Name
:
"Group One"
,
Description
:
"desc"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.5
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-*"
:
[]
int64
{
101
,
102
},
},
AccountCount
:
2
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
},
})
deps
.
userSubRepo
.
SetActiveByUserID
(
1
,
nil
)
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/groups/available"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 10,
"name": "Group One",
"description": "desc",
"platform": "anthropic",
"rate_multiplier": 1.5,
"is_exclusive": false,
"status": "active",
"subscription_type": "standard",
"daily_limit_usd": null,
"weekly_limit_usd": null,
"monthly_limit_usd": null,
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`
,
},
{
name
:
"GET /api/v1/subscriptions"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps
.
userSubRepo
.
SetByUserID
(
1
,
[]
service
.
UserSubscription
{
{
ID
:
501
,
UserID
:
1
,
GroupID
:
10
,
StartsAt
:
deps
.
now
,
ExpiresAt
:
time
.
Date
(
2099
,
1
,
2
,
3
,
4
,
5
,
0
,
time
.
UTC
),
// 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status
:
service
.
SubscriptionStatusActive
,
DailyUsageUSD
:
1.23
,
WeeklyUsageUSD
:
2.34
,
MonthlyUsageUSD
:
3.45
,
AssignedBy
:
ptr
(
int64
(
999
)),
AssignedAt
:
deps
.
now
,
Notes
:
"admin-note"
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
},
})
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/subscriptions"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 501,
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
"monthly_window_start": null,
"daily_usage_usd": 1.23,
"weekly_usage_usd": 2.34,
"monthly_usage_usd": 3.45,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`
,
},
{
name
:
"GET /api/v1/redeem/history"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户兑换历史不应包含 notes 等内部字段。
deps
.
redeemRepo
.
SetByUser
(
1
,
[]
service
.
RedeemCode
{
{
ID
:
900
,
Code
:
"CODE-123"
,
Type
:
service
.
RedeemTypeBalance
,
Value
:
1.25
,
Status
:
service
.
StatusUsed
,
UsedBy
:
ptr
(
int64
(
1
)),
UsedAt
:
ptr
(
deps
.
now
),
Notes
:
"internal-note"
,
CreatedAt
:
deps
.
now
,
},
})
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/redeem/history"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 900,
"code": "CODE-123",
"type": "balance",
"value": 1.25,
"status": "used",
"used_by": 1,
"used_at": "2025-01-02T03:04:05Z",
"created_at": "2025-01-02T03:04:05Z",
"group_id": null,
"validity_days": 0
}
]
}`
,
},
{
name
:
"GET /api/v1/usage/stats"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
...
...
@@ -194,6 +340,7 @@ func TestAPIContracts(t *testing.T) {
UserID
:
1
,
APIKeyID
:
100
,
AccountID
:
200
,
AccountRateMultiplier
:
ptr
(
0.5
),
RequestID
:
"req_123"
,
Model
:
"claude-3"
,
InputTokens
:
10
,
...
...
@@ -241,7 +388,6 @@ func TestAPIContracts(t *testing.T) {
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
...
...
@@ -266,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps
.
settingRepo
.
SetAll
(
map
[
string
]
string
{
service
.
SettingKeyRegistrationEnabled
:
"true"
,
service
.
SettingKeyEmailVerifyEnabled
:
"false"
,
service
.
SettingKeyPromoCodeEnabled
:
"true"
,
service
.
SettingKeySMTPHost
:
"smtp.example.com"
,
service
.
SettingKeySMTPPort
:
"587"
,
...
...
@@ -304,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
...
...
@@ -337,7 +488,10 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"home_content": ""
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": ""
}
}`
,
},
...
...
@@ -385,8 +539,11 @@ type contractDeps struct {
now
time
.
Time
router
http
.
Handler
apiKeyRepo
*
stubApiKeyRepo
groupRepo
*
stubGroupRepo
userSubRepo
*
stubUserSubscriptionRepo
usageRepo
*
stubUsageLogRepo
settingRepo
*
stubSettingRepo
redeemRepo
*
stubRedeemCodeRepo
}
func
newContractDeps
(
t
*
testing
.
T
)
*
contractDeps
{
...
...
@@ -414,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyRepo
:=
newStubApiKeyRepo
(
now
)
apiKeyCache
:=
stubApiKeyCache
{}
groupRepo
:=
stubGroupRepo
{}
userSubRepo
:=
stubUserSubscriptionRepo
{}
groupRepo
:=
&
stubGroupRepo
{}
userSubRepo
:=
&
stubUserSubscriptionRepo
{}
accountRepo
:=
stubAccountRepo
{}
proxyRepo
:=
stubProxyRepo
{}
redeemRepo
:=
stubRedeemCodeRepo
{}
redeemRepo
:=
&
stubRedeemCodeRepo
{}
cfg
:=
&
config
.
Config
{
Default
:
config
.
DefaultConfig
{
...
...
@@ -433,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
usageRepo
:=
newStubUsageLogRepo
()
usageService
:=
service
.
NewUsageService
(
usageRepo
,
userRepo
,
nil
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepo
,
userSubRepo
,
nil
)
subscriptionHandler
:=
handler
.
NewSubscriptionHandler
(
subscriptionService
)
redeemService
:=
service
.
NewRedeemService
(
redeemRepo
,
userRepo
,
subscriptionService
,
nil
,
nil
,
nil
,
nil
)
redeemHandler
:=
handler
.
NewRedeemHandler
(
redeemService
)
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
@@ -472,12 +635,21 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Keys
.
Use
(
jwtAuth
)
v1Keys
.
GET
(
"/keys"
,
apiKeyHandler
.
List
)
v1Keys
.
POST
(
"/keys"
,
apiKeyHandler
.
Create
)
v1Keys
.
GET
(
"/groups/available"
,
apiKeyHandler
.
GetAvailableGroups
)
v1Usage
:=
v1
.
Group
(
""
)
v1Usage
.
Use
(
jwtAuth
)
v1Usage
.
GET
(
"/usage"
,
usageHandler
.
List
)
v1Usage
.
GET
(
"/usage/stats"
,
usageHandler
.
Stats
)
v1Subs
:=
v1
.
Group
(
""
)
v1Subs
.
Use
(
jwtAuth
)
v1Subs
.
GET
(
"/subscriptions"
,
subscriptionHandler
.
List
)
v1Redeem
:=
v1
.
Group
(
""
)
v1Redeem
.
Use
(
jwtAuth
)
v1Redeem
.
GET
(
"/redeem/history"
,
redeemHandler
.
GetHistory
)
v1Admin
:=
v1
.
Group
(
"/admin"
)
v1Admin
.
Use
(
adminAuth
)
v1Admin
.
GET
(
"/settings"
,
adminSettingHandler
.
GetSettings
)
...
...
@@ -487,8 +659,11 @@ func newContractDeps(t *testing.T) *contractDeps {
now
:
now
,
router
:
r
,
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userSubRepo
:
userSubRepo
,
usageRepo
:
usageRepo
,
settingRepo
:
settingRepo
,
redeemRepo
:
redeemRepo
,
}
}
...
...
@@ -584,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
type
stubApiKeyCache
struct
{}
func
(
stubApiKeyCache
)
GetCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
...
...
@@ -618,7 +805,21 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return
nil
}
type
stubGroupRepo
struct
{}
func
(
stubApiKeyCache
)
PublishAuthCacheInvalidation
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
return
nil
}
func
(
stubApiKeyCache
)
SubscribeAuthCacheInvalidation
(
ctx
context
.
Context
,
handler
func
(
cacheKey
string
))
error
{
return
nil
}
type
stubGroupRepo
struct
{
active
[]
service
.
Group
}
func
(
r
*
stubGroupRepo
)
SetActive
(
groups
[]
service
.
Group
)
{
r
.
active
=
append
([]
service
.
Group
(
nil
),
groups
...
)
}
func
(
stubGroupRepo
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -652,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
append
([]
service
.
Group
(
nil
),
r
.
active
...
),
nil
}
func
(
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
out
:=
make
([]
service
.
Group
,
0
,
len
(
r
.
active
))
for
i
:=
range
r
.
active
{
g
:=
r
.
active
[
i
]
if
g
.
Platform
==
platform
{
out
=
append
(
out
,
g
)
}
}
return
out
,
nil
}
func
(
stubGroupRepo
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
...
...
@@ -736,6 +944,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -871,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubRedeemCodeRepo
struct
{}
type
stubRedeemCodeRepo
struct
{
byUser
map
[
int64
][]
service
.
RedeemCode
}
func
(
r
*
stubRedeemCodeRepo
)
SetByUser
(
userID
int64
,
codes
[]
service
.
RedeemCode
)
{
if
r
.
byUser
==
nil
{
r
.
byUser
=
make
(
map
[
int64
][]
service
.
RedeemCode
)
}
r
.
byUser
[
userID
]
=
append
([]
service
.
RedeemCode
(
nil
),
codes
...
)
}
func
(
stubRedeemCodeRepo
)
Create
(
ctx
context
.
Context
,
code
*
service
.
RedeemCode
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -909,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
service
.
RedeemCode
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubRedeemCodeRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
service
.
RedeemCode
,
error
)
{
if
r
.
byUser
==
nil
{
return
nil
,
nil
}
codes
:=
r
.
byUser
[
userID
]
if
limit
>
0
&&
len
(
codes
)
>
limit
{
codes
=
codes
[
:
limit
]
}
return
append
([]
service
.
RedeemCode
(
nil
),
codes
...
),
nil
}
type
stubUserSubscriptionRepo
struct
{
byUser
map
[
int64
][]
service
.
UserSubscription
activeByUser
map
[
int64
][]
service
.
UserSubscription
}
type
stubUserSubscriptionRepo
struct
{}
func
(
r
*
stubUserSubscriptionRepo
)
SetByUserID
(
userID
int64
,
subs
[]
service
.
UserSubscription
)
{
if
r
.
byUser
==
nil
{
r
.
byUser
=
make
(
map
[
int64
][]
service
.
UserSubscription
)
}
r
.
byUser
[
userID
]
=
append
([]
service
.
UserSubscription
(
nil
),
subs
...
)
}
func
(
r
*
stubUserSubscriptionRepo
)
SetActiveByUserID
(
userID
int64
,
subs
[]
service
.
UserSubscription
)
{
if
r
.
activeByUser
==
nil
{
r
.
activeByUser
=
make
(
map
[
int64
][]
service
.
UserSubscription
)
}
r
.
activeByUser
[
userID
]
=
append
([]
service
.
UserSubscription
(
nil
),
subs
...
)
}
func
(
stubUserSubscriptionRepo
)
Create
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -933,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
func
(
stubUserSubscriptionRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubUserSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
if
r
.
byUser
==
nil
{
return
nil
,
nil
}
return
append
([]
service
.
UserSubscription
(
nil
),
r
.
byUser
[
userID
]
...
),
nil
}
func
(
stubUserSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubUserSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
if
r
.
activeByUser
==
nil
{
return
nil
,
nil
}
return
append
([]
service
.
UserSubscription
(
nil
),
r
.
activeByUser
[
userID
]
...
),
nil
}
func
(
stubUserSubscriptionRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
...
...
@@ -1242,11 +1493,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
2fe8932c
...
...
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/routes/admin.go
View file @
2fe8932c
...
...
@@ -354,6 +354,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage
.
GET
(
"/stats"
,
h
.
Admin
.
Usage
.
Stats
)
usage
.
GET
(
"/search-users"
,
h
.
Admin
.
Usage
.
SearchUsers
)
usage
.
GET
(
"/search-api-keys"
,
h
.
Admin
.
Usage
.
SearchAPIKeys
)
usage
.
GET
(
"/cleanup-tasks"
,
h
.
Admin
.
Usage
.
ListCleanupTasks
)
usage
.
POST
(
"/cleanup-tasks"
,
h
.
Admin
.
Usage
.
CreateCleanupTask
)
usage
.
POST
(
"/cleanup-tasks/:id/cancel"
,
h
.
Admin
.
Usage
.
CancelCleanupTask
)
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment