Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0170d19f
Commit
0170d19f
authored
Feb 02, 2026
by
song
Browse files
merge upstream main
parent
7ade9baa
Changes
319
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/usage_log_repo_integration_test.go
View file @
0170d19f
...
...
@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime
:=
base
.
Add
(
48
*
time
.
Hour
)
// Test with user filter
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters user filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with apiKey filter
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with both filters
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters both filters"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime
:=
base
.
Add
(
-
1
*
time
.
Hour
)
endTime
:=
base
.
Add
(
3
*
time
.
Hour
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters hourly"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime
:=
base
.
Add
(
2
*
time
.
Hour
)
// Test with user filter
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
)
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters user filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with apiKey filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with account filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters account filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
}
...
...
backend/internal/repository/user_repo.go
View file @
0170d19f
...
...
@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbuser
"github.com/Wei-Shaw/sub2api/ent/user"
...
...
@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
dbuser
.
Or
(
dbuser
.
EmailContainsFold
(
filters
.
Search
),
dbuser
.
UsernameContainsFold
(
filters
.
Search
),
dbuser
.
NotesContainsFold
(
filters
.
Search
),
),
)
}
...
...
@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst
.
CreatedAt
=
src
.
CreatedAt
dst
.
UpdatedAt
=
src
.
UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func
(
r
*
userRepository
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
update
:=
client
.
User
.
UpdateOneID
(
userID
)
if
encryptedSecret
==
nil
{
update
=
update
.
ClearTotpSecretEncrypted
()
}
else
{
update
=
update
.
SetTotpSecretEncrypted
(
*
encryptedSecret
)
}
_
,
err
:=
update
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
true
)
.
SetTotpEnabledAt
(
time
.
Now
())
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
false
)
.
ClearTotpEnabledAt
()
.
ClearTotpSecretEncrypted
()
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
backend/internal/repository/user_subscription_repo.go
View file @
0170d19f
...
...
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return
userSubscriptionEntitiesToService
(
subs
),
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
q
:=
client
.
UserSubscription
.
Query
()
if
userID
!=
nil
{
...
...
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if
groupID
!=
nil
{
q
=
q
.
Where
(
usersubscription
.
GroupIDEQ
(
*
groupID
))
}
if
status
!=
""
{
// Status filtering with real-time expiration check
now
:=
time
.
Now
()
switch
status
{
case
service
.
SubscriptionStatusActive
:
// Active: status is active AND not yet expired
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtGT
(
now
),
)
case
service
.
SubscriptionStatusExpired
:
// Expired: status is expired OR (status is active but already expired)
q
=
q
.
Where
(
usersubscription
.
Or
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusExpired
),
usersubscription
.
And
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtLTE
(
now
),
),
),
)
case
""
:
// No filter
default
:
// Other status (e.g., revoked)
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
status
))
}
...
...
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return
nil
,
nil
,
err
}
// Apply sorting
q
=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
// Determine sort field
var
field
string
switch
sortBy
{
case
"expires_at"
:
field
=
usersubscription
.
FieldExpiresAt
case
"status"
:
field
=
usersubscription
.
FieldStatus
default
:
field
=
usersubscription
.
FieldCreatedAt
}
// Determine sort order (default: desc)
if
sortOrder
==
"asc"
&&
sortBy
!=
""
{
q
=
q
.
Order
(
dbent
.
Asc
(
field
))
}
else
{
q
=
q
.
Order
(
dbent
.
Desc
(
field
))
}
subs
,
err
:=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
.
Order
(
dbent
.
Desc
(
usersubscription
.
FieldCreatedAt
))
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
All
(
ctx
)
...
...
backend/internal/repository/user_subscription_repo_integration_test.go
View file @
0170d19f
...
...
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group
:=
s
.
mustCreateGroup
(
"g-list"
)
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
...
...
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s
.
mustCreateSubscription
(
user1
.
ID
,
group
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user2
.
ID
,
group
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
user1
.
ID
,
subs
[
0
]
.
UserID
)
...
...
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s
.
mustCreateSubscription
(
user
.
ID
,
g1
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user
.
ID
,
g2
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
g1
.
ID
,
subs
[
0
]
.
GroupID
)
...
...
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
-
24
*
time
.
Hour
))
})
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
service
.
SubscriptionStatusExpired
,
subs
[
0
]
.
Status
)
...
...
backend/internal/repository/wire.go
View file @
0170d19f
...
...
@@ -56,7 +56,10 @@ var ProviderSet = wire.NewSet(
NewProxyRepository
,
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewAnnouncementRepository
,
NewAnnouncementReadRepository
,
NewUsageLogRepository
,
NewUsageCleanupRepository
,
NewDashboardAggregationRepository
,
NewSettingRepository
,
NewOpsRepository
,
...
...
@@ -81,6 +84,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache
,
NewSchedulerOutboxRepository
,
NewProxyLatencyCache
,
NewTotpCache
,
// Encryptors
NewAESEncryptor
,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
...
...
backend/internal/server/api_contract_test.go
View file @
0170d19f
...
...
@@ -11,7 +11,6 @@ import (
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"
"time"
...
...
@@ -52,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
"concurrency": 5,
...
...
@@ -132,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
}
}`
,
},
{
name
:
"GET /api/v1/groups/available"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count)。
deps
.
groupRepo
.
SetActive
([]
service
.
Group
{
{
ID
:
10
,
Name
:
"Group One"
,
Description
:
"desc"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.5
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-*"
:
[]
int64
{
101
,
102
},
},
AccountCount
:
2
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
},
})
deps
.
userSubRepo
.
SetActiveByUserID
(
1
,
nil
)
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/groups/available"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 10,
"name": "Group One",
"description": "desc",
"platform": "anthropic",
"rate_multiplier": 1.5,
"is_exclusive": false,
"status": "active",
"subscription_type": "standard",
"daily_limit_usd": null,
"weekly_limit_usd": null,
"monthly_limit_usd": null,
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`
,
},
{
name
:
"GET /api/v1/subscriptions"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps
.
userSubRepo
.
SetByUserID
(
1
,
[]
service
.
UserSubscription
{
{
ID
:
501
,
UserID
:
1
,
GroupID
:
10
,
StartsAt
:
deps
.
now
,
ExpiresAt
:
time
.
Date
(
2099
,
1
,
2
,
3
,
4
,
5
,
0
,
time
.
UTC
),
// 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status
:
service
.
SubscriptionStatusActive
,
DailyUsageUSD
:
1.23
,
WeeklyUsageUSD
:
2.34
,
MonthlyUsageUSD
:
3.45
,
AssignedBy
:
ptr
(
int64
(
999
)),
AssignedAt
:
deps
.
now
,
Notes
:
"admin-note"
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
},
})
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/subscriptions"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 501,
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
"monthly_window_start": null,
"daily_usage_usd": 1.23,
"weekly_usage_usd": 2.34,
"monthly_usage_usd": 3.45,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`
,
},
{
name
:
"GET /api/v1/redeem/history"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户兑换历史不应包含 notes 等内部字段。
deps
.
redeemRepo
.
SetByUser
(
1
,
[]
service
.
RedeemCode
{
{
ID
:
900
,
Code
:
"CODE-123"
,
Type
:
service
.
RedeemTypeBalance
,
Value
:
1.25
,
Status
:
service
.
StatusUsed
,
UsedBy
:
ptr
(
int64
(
1
)),
UsedAt
:
ptr
(
deps
.
now
),
Notes
:
"internal-note"
,
CreatedAt
:
deps
.
now
,
},
})
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/redeem/history"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 900,
"code": "CODE-123",
"type": "balance",
"value": 1.25,
"status": "used",
"used_by": 1,
"used_at": "2025-01-02T03:04:05Z",
"created_at": "2025-01-02T03:04:05Z",
"group_id": null,
"validity_days": 0
}
]
}`
,
},
{
name
:
"GET /api/v1/usage/stats"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
...
...
@@ -191,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
t
.
Helper
()
deps
.
usageRepo
.
SetUserLogs
(
1
,
[]
service
.
UsageLog
{
{
ID
:
1
,
UserID
:
1
,
APIKeyID
:
100
,
AccountID
:
200
,
RequestID
:
"req_123"
,
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
CacheCreationTokens
:
1
,
CacheReadTokens
:
2
,
TotalCost
:
0.5
,
ActualCost
:
0.5
,
RateMultiplier
:
1
,
BillingType
:
service
.
BillingTypeBalance
,
Stream
:
true
,
DurationMs
:
ptr
(
100
),
FirstTokenMs
:
ptr
(
50
),
CreatedAt
:
deps
.
now
,
ID
:
1
,
UserID
:
1
,
APIKeyID
:
100
,
AccountID
:
200
,
AccountRateMultiplier
:
ptr
(
0.5
),
RequestID
:
"req_123"
,
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
CacheCreationTokens
:
1
,
CacheReadTokens
:
2
,
TotalCost
:
0.5
,
ActualCost
:
0.5
,
RateMultiplier
:
1
,
BillingType
:
service
.
BillingTypeBalance
,
Stream
:
true
,
DurationMs
:
ptr
(
100
),
FirstTokenMs
:
ptr
(
50
),
CreatedAt
:
deps
.
now
,
},
})
},
...
...
@@ -239,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
"output_cost": 0,
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
...
...
@@ -267,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps
.
settingRepo
.
SetAll
(
map
[
string
]
string
{
service
.
SettingKeyRegistrationEnabled
:
"true"
,
service
.
SettingKeyEmailVerifyEnabled
:
"false"
,
service
.
SettingKeyPromoCodeEnabled
:
"true"
,
service
.
SettingKeySMTPHost
:
"smtp.example.com"
,
service
.
SettingKeySMTPPort
:
"587"
,
...
...
@@ -305,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
...
...
@@ -338,45 +488,10 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"home_content": ""
}
}`
,
},
{
name
:
"POST /api/v1/admin/accounts/lookup"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
deps
.
accountRepo
.
lookupAccounts
=
[]
service
.
Account
{
{
ID
:
101
,
Name
:
"Alice Account"
,
Platform
:
"antigravity"
,
Credentials
:
map
[
string
]
any
{
"email"
:
"alice@example.com"
,
},
},
}
},
method
:
http
.
MethodPost
,
path
:
"/api/v1/admin/accounts/lookup"
,
body
:
`{"platform":"antigravity","emails":["Alice@Example.com","bob@example.com"]}`
,
headers
:
map
[
string
]
string
{
"Content-Type"
:
"application/json"
,
},
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": {
"matched": [
{
"email": "alice@example.com",
"account_id": 101,
"platform": "antigravity",
"name": "Alice Account"
}
],
"missing": ["bob@example.com"]
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": ""
}
}`
,
},
...
...
@@ -424,9 +539,11 @@ type contractDeps struct {
now
time
.
Time
router
http
.
Handler
apiKeyRepo
*
stubApiKeyRepo
groupRepo
*
stubGroupRepo
userSubRepo
*
stubUserSubscriptionRepo
usageRepo
*
stubUsageLogRepo
settingRepo
*
stubSettingRepo
account
Repo
*
stub
Account
Repo
redeem
Repo
*
stub
RedeemCode
Repo
}
func
newContractDeps
(
t
*
testing
.
T
)
*
contractDeps
{
...
...
@@ -454,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyRepo
:=
newStubApiKeyRepo
(
now
)
apiKeyCache
:=
stubApiKeyCache
{}
groupRepo
:=
stubGroupRepo
{}
userSubRepo
:=
stubUserSubscriptionRepo
{}
groupRepo
:=
&
stubGroupRepo
{}
userSubRepo
:=
&
stubUserSubscriptionRepo
{}
accountRepo
:=
stubAccountRepo
{}
proxyRepo
:=
stubProxyRepo
{}
redeemRepo
:=
stubRedeemCodeRepo
{}
redeemRepo
:=
&
stubRedeemCodeRepo
{}
cfg
:=
&
config
.
Config
{
Default
:
config
.
DefaultConfig
{
...
...
@@ -473,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
usageRepo
:=
newStubUsageLogRepo
()
usageService
:=
service
.
NewUsageService
(
usageRepo
,
userRepo
,
nil
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepo
,
userSubRepo
,
nil
)
subscriptionHandler
:=
handler
.
NewSubscriptionHandler
(
subscriptionService
)
redeemService
:=
service
.
NewRedeemService
(
redeemRepo
,
userRepo
,
subscriptionService
,
nil
,
nil
,
nil
,
nil
)
redeemHandler
:=
handler
.
NewRedeemHandler
(
redeemService
)
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
@@ -512,25 +635,35 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Keys
.
Use
(
jwtAuth
)
v1Keys
.
GET
(
"/keys"
,
apiKeyHandler
.
List
)
v1Keys
.
POST
(
"/keys"
,
apiKeyHandler
.
Create
)
v1Keys
.
GET
(
"/groups/available"
,
apiKeyHandler
.
GetAvailableGroups
)
v1Usage
:=
v1
.
Group
(
""
)
v1Usage
.
Use
(
jwtAuth
)
v1Usage
.
GET
(
"/usage"
,
usageHandler
.
List
)
v1Usage
.
GET
(
"/usage/stats"
,
usageHandler
.
Stats
)
v1Subs
:=
v1
.
Group
(
""
)
v1Subs
.
Use
(
jwtAuth
)
v1Subs
.
GET
(
"/subscriptions"
,
subscriptionHandler
.
List
)
v1Redeem
:=
v1
.
Group
(
""
)
v1Redeem
.
Use
(
jwtAuth
)
v1Redeem
.
GET
(
"/redeem/history"
,
redeemHandler
.
GetHistory
)
v1Admin
:=
v1
.
Group
(
"/admin"
)
v1Admin
.
Use
(
adminAuth
)
v1Admin
.
GET
(
"/settings"
,
adminSettingHandler
.
GetSettings
)
v1Admin
.
POST
(
"/accounts/bulk-update"
,
adminAccountHandler
.
BulkUpdate
)
v1Admin
.
POST
(
"/accounts/lookup"
,
adminAccountHandler
.
Lookup
)
return
&
contractDeps
{
now
:
now
,
router
:
r
,
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userSubRepo
:
userSubRepo
,
usageRepo
:
usageRepo
,
settingRepo
:
settingRepo
,
accountRepo
:
&
account
Repo
,
redeemRepo
:
redeem
Repo
,
}
}
...
...
@@ -626,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
type
stubApiKeyCache
struct
{}
func
(
stubApiKeyCache
)
GetCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
...
...
@@ -660,7 +805,21 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return
nil
}
type
stubGroupRepo
struct
{}
func
(
stubApiKeyCache
)
PublishAuthCacheInvalidation
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
return
nil
}
func
(
stubApiKeyCache
)
SubscribeAuthCacheInvalidation
(
ctx
context
.
Context
,
handler
func
(
cacheKey
string
))
error
{
return
nil
}
type
stubGroupRepo
struct
{
active
[]
service
.
Group
}
func
(
r
*
stubGroupRepo
)
SetActive
(
groups
[]
service
.
Group
)
{
r
.
active
=
append
([]
service
.
Group
(
nil
),
groups
...
)
}
func
(
stubGroupRepo
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -694,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
append
([]
service
.
Group
(
nil
),
r
.
active
...
),
nil
}
func
(
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
out
:=
make
([]
service
.
Group
,
0
,
len
(
r
.
active
))
for
i
:=
range
r
.
active
{
g
:=
r
.
active
[
i
]
if
g
.
Platform
==
platform
{
out
=
append
(
out
,
g
)
}
}
return
out
,
nil
}
func
(
stubGroupRepo
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
...
...
@@ -715,8 +881,7 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
}
type
stubAccountRepo
struct
{
bulkUpdateIDs
[]
int64
lookupAccounts
[]
service
.
Account
bulkUpdateIDs
[]
int64
}
func
(
s
*
stubAccountRepo
)
Create
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
...
...
@@ -767,36 +932,6 @@ func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) (
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListByPlatformAndCredentialEmails
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
s
.
lookupAccounts
)
==
0
{
return
nil
,
nil
}
emailSet
:=
make
(
map
[
string
]
struct
{},
len
(
emails
))
for
_
,
email
:=
range
emails
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
if
normalized
==
""
{
continue
}
emailSet
[
normalized
]
=
struct
{}{}
}
var
matches
[]
service
.
Account
for
i
:=
range
s
.
lookupAccounts
{
account
:=
&
s
.
lookupAccounts
[
i
]
if
account
.
Platform
!=
platform
{
continue
}
accountEmail
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
account
.
GetCredential
(
"email"
)))
if
accountEmail
==
""
{
continue
}
if
_
,
ok
:=
emailSet
[
accountEmail
];
!
ok
{
continue
}
matches
=
append
(
matches
,
*
account
)
}
return
matches
,
nil
}
func
(
s
*
stubAccountRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -948,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubRedeemCodeRepo
struct
{}
type
stubRedeemCodeRepo
struct
{
byUser
map
[
int64
][]
service
.
RedeemCode
}
func
(
r
*
stubRedeemCodeRepo
)
SetByUser
(
userID
int64
,
codes
[]
service
.
RedeemCode
)
{
if
r
.
byUser
==
nil
{
r
.
byUser
=
make
(
map
[
int64
][]
service
.
RedeemCode
)
}
r
.
byUser
[
userID
]
=
append
([]
service
.
RedeemCode
(
nil
),
codes
...
)
}
func
(
stubRedeemCodeRepo
)
Create
(
ctx
context
.
Context
,
code
*
service
.
RedeemCode
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -986,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
service
.
RedeemCode
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubRedeemCodeRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
service
.
RedeemCode
,
error
)
{
if
r
.
byUser
==
nil
{
return
nil
,
nil
}
codes
:=
r
.
byUser
[
userID
]
if
limit
>
0
&&
len
(
codes
)
>
limit
{
codes
=
codes
[
:
limit
]
}
return
append
([]
service
.
RedeemCode
(
nil
),
codes
...
),
nil
}
type
stubUserSubscriptionRepo
struct
{
byUser
map
[
int64
][]
service
.
UserSubscription
activeByUser
map
[
int64
][]
service
.
UserSubscription
}
func
(
r
*
stubUserSubscriptionRepo
)
SetByUserID
(
userID
int64
,
subs
[]
service
.
UserSubscription
)
{
if
r
.
byUser
==
nil
{
r
.
byUser
=
make
(
map
[
int64
][]
service
.
UserSubscription
)
}
r
.
byUser
[
userID
]
=
append
([]
service
.
UserSubscription
(
nil
),
subs
...
)
}
type
stubUserSubscriptionRepo
struct
{}
func
(
r
*
stubUserSubscriptionRepo
)
SetActiveByUserID
(
userID
int64
,
subs
[]
service
.
UserSubscription
)
{
if
r
.
activeByUser
==
nil
{
r
.
activeByUser
=
make
(
map
[
int64
][]
service
.
UserSubscription
)
}
r
.
activeByUser
[
userID
]
=
append
([]
service
.
UserSubscription
(
nil
),
subs
...
)
}
func
(
stubUserSubscriptionRepo
)
Create
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -1010,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
func
(
stubUserSubscriptionRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubUserSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
if
r
.
byUser
==
nil
{
return
nil
,
nil
}
return
append
([]
service
.
UserSubscription
(
nil
),
r
.
byUser
[
userID
]
...
),
nil
}
func
(
stubUserSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubUserSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
if
r
.
activeByUser
==
nil
{
return
nil
,
nil
}
return
append
([]
service
.
UserSubscription
(
nil
),
r
.
activeByUser
[
userID
]
...
),
nil
}
func
(
stubUserSubscriptionRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
...
...
@@ -1319,11 +1493,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
0170d19f
...
...
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/routes/admin.go
View file @
0170d19f
...
...
@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
// 账号管理
registerAccountRoutes
(
admin
,
h
)
// 公告管理
registerAnnouncementRoutes
(
admin
,
h
)
// OpenAI OAuth
registerOpenAIOAuthRoutes
(
admin
,
h
)
...
...
@@ -197,7 +200,6 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts
:=
admin
.
Group
(
"/accounts"
)
{
accounts
.
GET
(
""
,
h
.
Admin
.
Account
.
List
)
accounts
.
POST
(
"/lookup"
,
h
.
Admin
.
Account
.
Lookup
)
accounts
.
GET
(
"/:id"
,
h
.
Admin
.
Account
.
GetByID
)
accounts
.
POST
(
""
,
h
.
Admin
.
Account
.
Create
)
accounts
.
POST
(
"/sync/crs"
,
h
.
Admin
.
Account
.
SyncFromCRS
)
...
...
@@ -230,6 +232,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerAnnouncementRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
announcements
:=
admin
.
Group
(
"/announcements"
)
{
announcements
.
GET
(
""
,
h
.
Admin
.
Announcement
.
List
)
announcements
.
POST
(
""
,
h
.
Admin
.
Announcement
.
Create
)
announcements
.
GET
(
"/:id"
,
h
.
Admin
.
Announcement
.
GetByID
)
announcements
.
PUT
(
"/:id"
,
h
.
Admin
.
Announcement
.
Update
)
announcements
.
DELETE
(
"/:id"
,
h
.
Admin
.
Announcement
.
Delete
)
announcements
.
GET
(
"/:id/read-status"
,
h
.
Admin
.
Announcement
.
ListReadStatus
)
}
}
func
registerOpenAIOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
openai
:=
admin
.
Group
(
"/openai"
)
{
...
...
@@ -355,6 +369,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage
.
GET
(
"/stats"
,
h
.
Admin
.
Usage
.
Stats
)
usage
.
GET
(
"/search-users"
,
h
.
Admin
.
Usage
.
SearchUsers
)
usage
.
GET
(
"/search-api-keys"
,
h
.
Admin
.
Usage
.
SearchAPIKeys
)
usage
.
GET
(
"/cleanup-tasks"
,
h
.
Admin
.
Usage
.
ListCleanupTasks
)
usage
.
POST
(
"/cleanup-tasks"
,
h
.
Admin
.
Usage
.
CreateCleanupTask
)
usage
.
POST
(
"/cleanup-tasks/:id/cancel"
,
h
.
Admin
.
Usage
.
CancelCleanupTask
)
}
}
...
...
backend/internal/server/routes/auth.go
View file @
0170d19f
...
...
@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/validate-promo-code"
,
rateLimiter
.
LimitWithOptions
(
"validate-promo"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ValidatePromoCode
)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth
.
POST
(
"/forgot-password"
,
rateLimiter
.
LimitWithOptions
(
"forgot-password"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ForgotPassword
)
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/reset-password"
,
rateLimiter
.
LimitWithOptions
(
"reset-password"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ResetPassword
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
...
...
backend/internal/server/routes/user.go
View file @
0170d19f
...
...
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user
.
GET
(
"/profile"
,
h
.
User
.
GetProfile
)
user
.
PUT
(
"/password"
,
h
.
User
.
ChangePassword
)
user
.
PUT
(
""
,
h
.
User
.
UpdateProfile
)
// TOTP 双因素认证
totp
:=
user
.
Group
(
"/totp"
)
{
totp
.
GET
(
"/status"
,
h
.
Totp
.
GetStatus
)
totp
.
GET
(
"/verification-method"
,
h
.
Totp
.
GetVerificationMethod
)
totp
.
POST
(
"/send-code"
,
h
.
Totp
.
SendVerifyCode
)
totp
.
POST
(
"/setup"
,
h
.
Totp
.
InitiateSetup
)
totp
.
POST
(
"/enable"
,
h
.
Totp
.
Enable
)
totp
.
POST
(
"/disable"
,
h
.
Totp
.
Disable
)
}
}
// API Key管理
...
...
@@ -53,6 +64,13 @@ func RegisterUserRoutes(
usage
.
POST
(
"/dashboard/api-keys-usage"
,
h
.
Usage
.
DashboardAPIKeysUsage
)
}
// 公告(用户可见)
announcements
:=
authenticated
.
Group
(
"/announcements"
)
{
announcements
.
GET
(
""
,
h
.
Announcement
.
List
)
announcements
.
POST
(
"/:id/read"
,
h
.
Announcement
.
MarkRead
)
}
// 卡密兑换
redeem
:=
authenticated
.
Group
(
"/redeem"
)
{
...
...
backend/internal/service/account.go
View file @
0170d19f
...
...
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return
nil
}
// GetCredentialAsInt64 解析凭证中的 int64 字段
// 用于读取 _token_version 等内部字段
func
(
a
*
Account
)
GetCredentialAsInt64
(
key
string
)
int64
{
if
a
==
nil
||
a
.
Credentials
==
nil
{
return
0
}
val
,
ok
:=
a
.
Credentials
[
key
]
if
!
ok
||
val
==
nil
{
return
0
}
switch
v
:=
val
.
(
type
)
{
case
int64
:
return
v
case
float64
:
return
int64
(
v
)
case
int
:
return
int64
(
v
)
case
json
.
Number
:
if
i
,
err
:=
v
.
Int64
();
err
==
nil
{
return
i
}
case
string
:
if
i
,
err
:=
strconv
.
ParseInt
(
strings
.
TrimSpace
(
v
),
10
,
64
);
err
==
nil
{
return
i
}
}
return
0
}
func
(
a
*
Account
)
IsTempUnschedulableEnabled
()
bool
{
if
a
.
Credentials
==
nil
{
return
false
...
...
@@ -576,6 +605,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeSetupToken
)
}
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
func
(
a
*
Account
)
IsTLSFingerprintEnabled
()
bool
{
// 仅支持 Anthropic OAuth/SetupToken 账号
if
!
a
.
IsAnthropicOAuthOrSetupToken
()
{
return
false
}
if
a
.
Extra
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Extra
[
"enable_tls_fingerprint"
];
ok
{
if
enabled
,
ok
:=
v
.
(
bool
);
ok
{
return
enabled
}
}
return
false
}
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
// 使上游认为请求来自同一个会话
func
(
a
*
Account
)
IsSessionIDMaskingEnabled
()
bool
{
if
!
a
.
IsAnthropicOAuthOrSetupToken
()
{
return
false
}
if
a
.
Extra
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Extra
[
"session_id_masking_enabled"
];
ok
{
if
enabled
,
ok
:=
v
.
(
bool
);
ok
{
return
enabled
}
}
return
false
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
...
...
@@ -652,6 +719,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
return
WindowCostNotSchedulable
}
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
// 逻辑:
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
func
(
a
*
Account
)
GetCurrentWindowStartTime
()
time
.
Time
{
now
:=
time
.
Now
()
// 窗口未过期,使用记录的窗口开始时间
if
a
.
SessionWindowStart
!=
nil
&&
a
.
SessionWindowEnd
!=
nil
&&
now
.
Before
(
*
a
.
SessionWindowEnd
)
{
return
*
a
.
SessionWindowStart
}
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
return
time
.
Date
(
now
.
Year
(),
now
.
Month
(),
now
.
Day
(),
now
.
Hour
(),
0
,
0
,
0
,
now
.
Location
())
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func
parseExtraFloat64
(
value
any
)
float64
{
switch
v
:=
value
.
(
type
)
{
...
...
backend/internal/service/account_service.go
View file @
0170d19f
...
...
@@ -33,7 +33,6 @@ type AccountRepository interface {
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
ListActive
(
ctx
context
.
Context
)
([]
Account
,
error
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
ListByPlatformAndCredentialEmails
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
...
...
backend/internal/service/account_service_delete_test.go
View file @
0170d19f
...
...
@@ -87,10 +87,6 @@ func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) (
panic
(
"unexpected ListByPlatform call"
)
}
func
(
s
*
accountRepoStub
)
ListByPlatformAndCredentialEmails
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
{
panic
(
"unexpected ListByPlatformAndCredentialEmails call"
)
}
func
(
s
*
accountRepoStub
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected UpdateLastUsed call"
)
}
...
...
backend/internal/service/account_test_service.go
View file @
0170d19f
...
...
@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
backend/internal/service/account_usage_service.go
View file @
0170d19f
...
...
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
...
...
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
}
`json:"seven_day_sonnet"`
}
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
type
ClaudeUsageFetchOptions
struct
{
AccessToken
string
// OAuth access token
ProxyURL
string
// 代理 URL(可选)
AccountID
int64
// 账号 ID(用于 TLS 指纹选择)
EnableTLSFingerprint
bool
// 是否启用 TLS 指纹伪装
Fingerprint
*
Fingerprint
// 缓存的指纹信息(User-Agent 等)
}
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type
ClaudeUsageFetcher
interface
{
FetchUsage
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
*
ClaudeUsageResponse
,
error
)
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
FetchUsageWithOptions
(
ctx
context
.
Context
,
opts
*
ClaudeUsageFetchOptions
)
(
*
ClaudeUsageResponse
,
error
)
}
// AccountUsageService 账号使用量查询服务
...
...
@@ -170,6 +181,7 @@ type AccountUsageService struct {
geminiQuotaService
*
GeminiQuotaService
antigravityQuotaFetcher
*
AntigravityQuotaFetcher
cache
*
UsageCache
identityCache
IdentityCache
}
// NewAccountUsageService 创建AccountUsageService实例
...
...
@@ -180,6 +192,7 @@ func NewAccountUsageService(
geminiQuotaService
*
GeminiQuotaService
,
antigravityQuotaFetcher
*
AntigravityQuotaFetcher
,
cache
*
UsageCache
,
identityCache
IdentityCache
,
)
*
AccountUsageService
{
return
&
AccountUsageService
{
accountRepo
:
accountRepo
,
...
...
@@ -188,6 +201,7 @@ func NewAccountUsageService(
geminiQuotaService
:
geminiQuotaService
,
antigravityQuotaFetcher
:
antigravityQuotaFetcher
,
cache
:
cache
,
identityCache
:
identityCache
,
}
}
...
...
@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart
:=
geminiDailyWindowStart
(
now
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini usage stats failed: %w"
,
err
)
}
...
...
@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart
:=
now
.
Truncate
(
time
.
Minute
)
minuteResetAt
:=
minuteStart
.
Add
(
time
.
Minute
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini minute usage stats failed: %w"
,
err
)
}
...
...
@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 如果没有缓存,从数据库查询
if
windowStats
==
nil
{
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime
:=
account
.
GetCurrentWindowStartTime
()
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
...
...
@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
}
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
func
(
s
*
AccountUsageService
)
fetchOAuthUsageRaw
(
ctx
context
.
Context
,
account
*
Account
)
(
*
ClaudeUsageResponse
,
error
)
{
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
...
...
@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
proxyURL
=
account
.
Proxy
.
URL
()
}
return
s
.
usageFetcher
.
FetchUsage
(
ctx
,
accessToken
,
proxyURL
)
// 构建完整的选项
opts
:=
&
ClaudeUsageFetchOptions
{
AccessToken
:
accessToken
,
ProxyURL
:
proxyURL
,
AccountID
:
account
.
ID
,
EnableTLSFingerprint
:
account
.
IsTLSFingerprintEnabled
(),
}
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
if
s
.
identityCache
!=
nil
{
if
fp
,
err
:=
s
.
identityCache
.
GetFingerprint
(
ctx
,
account
.
ID
);
err
==
nil
&&
fp
!=
nil
{
opts
.
Fingerprint
=
fp
}
}
return
s
.
usageFetcher
.
FetchUsageWithOptions
(
ctx
,
opts
)
}
// parseTime 尝试多种格式解析时间
...
...
backend/internal/service/admin_service.go
View file @
0170d19f
...
...
@@ -40,7 +40,6 @@ type AdminService interface {
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
Account
,
error
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAccountInput
)
(
*
Account
,
error
)
DeleteAccount
(
ctx
context
.
Context
,
id
int64
)
error
LookupAccountsByCredentialEmail
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
RefreshAccountCredentials
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
ClearAccountError
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
SetAccountError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
...
...
@@ -866,13 +865,6 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
LookupAccountsByCredentialEmail
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
{
if
platform
==
""
||
len
(
emails
)
==
0
{
return
[]
Account
{},
nil
}
return
s
.
accountRepo
.
ListByPlatformAndCredentialEmails
(
ctx
,
platform
,
emails
)
}
func
(
s
*
adminServiceImpl
)
GetAccountsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
{
if
len
(
ids
)
==
0
{
return
[]
*
Account
{},
nil
...
...
backend/internal/service/admin_service_delete_test.go
View file @
0170d19f
...
...
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic
(
"unexpected RemoveGroupFromAllowedGroups call"
)
}
func
(
s
*
userRepoStub
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
panic
(
"unexpected UpdateTotpSecret call"
)
}
func
(
s
*
userRepoStub
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected EnableTotp call"
)
}
func
(
s
*
userRepoStub
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected DisableTotp call"
)
}
type
groupRepoStub
struct
{
affectedUserIDs
[]
int64
deleteErr
error
...
...
backend/internal/service/announcement.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const
(
AnnouncementStatusDraft
=
domain
.
AnnouncementStatusDraft
AnnouncementStatusActive
=
domain
.
AnnouncementStatusActive
AnnouncementStatusArchived
=
domain
.
AnnouncementStatusArchived
)
const
(
AnnouncementConditionTypeSubscription
=
domain
.
AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance
=
domain
.
AnnouncementConditionTypeBalance
)
const
(
AnnouncementOperatorIn
=
domain
.
AnnouncementOperatorIn
AnnouncementOperatorGT
=
domain
.
AnnouncementOperatorGT
AnnouncementOperatorGTE
=
domain
.
AnnouncementOperatorGTE
AnnouncementOperatorLT
=
domain
.
AnnouncementOperatorLT
AnnouncementOperatorLTE
=
domain
.
AnnouncementOperatorLTE
AnnouncementOperatorEQ
=
domain
.
AnnouncementOperatorEQ
)
var
(
ErrAnnouncementNotFound
=
domain
.
ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget
=
domain
.
ErrAnnouncementInvalidTarget
)
type
AnnouncementTargeting
=
domain
.
AnnouncementTargeting
type
AnnouncementConditionGroup
=
domain
.
AnnouncementConditionGroup
type
AnnouncementCondition
=
domain
.
AnnouncementCondition
type
Announcement
=
domain
.
Announcement
type
AnnouncementListFilters
struct
{
Status
string
Search
string
}
type
AnnouncementRepository
interface
{
Create
(
ctx
context
.
Context
,
a
*
Announcement
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Announcement
,
error
)
Update
(
ctx
context
.
Context
,
a
*
Announcement
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
AnnouncementListFilters
)
([]
Announcement
,
*
pagination
.
PaginationResult
,
error
)
ListActive
(
ctx
context
.
Context
,
now
time
.
Time
)
([]
Announcement
,
error
)
}
type
AnnouncementReadRepository
interface
{
MarkRead
(
ctx
context
.
Context
,
announcementID
,
userID
int64
,
readAt
time
.
Time
)
error
GetReadMapByUser
(
ctx
context
.
Context
,
userID
int64
,
announcementIDs
[]
int64
)
(
map
[
int64
]
time
.
Time
,
error
)
GetReadMapByUsers
(
ctx
context
.
Context
,
announcementID
int64
,
userIDs
[]
int64
)
(
map
[
int64
]
time
.
Time
,
error
)
CountByAnnouncementID
(
ctx
context
.
Context
,
announcementID
int64
)
(
int64
,
error
)
}
backend/internal/service/announcement_service.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type
AnnouncementService
struct
{
announcementRepo
AnnouncementRepository
readRepo
AnnouncementReadRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
}
func
NewAnnouncementService
(
announcementRepo
AnnouncementRepository
,
readRepo
AnnouncementReadRepository
,
userRepo
UserRepository
,
userSubRepo
UserSubscriptionRepository
,
)
*
AnnouncementService
{
return
&
AnnouncementService
{
announcementRepo
:
announcementRepo
,
readRepo
:
readRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
}
}
type
CreateAnnouncementInput
struct
{
Title
string
Content
string
Status
string
Targeting
AnnouncementTargeting
StartsAt
*
time
.
Time
EndsAt
*
time
.
Time
ActorID
*
int64
// 管理员用户ID
}
type
UpdateAnnouncementInput
struct
{
Title
*
string
Content
*
string
Status
*
string
Targeting
*
AnnouncementTargeting
StartsAt
**
time
.
Time
EndsAt
**
time
.
Time
ActorID
*
int64
// 管理员用户ID
}
type
UserAnnouncement
struct
{
Announcement
Announcement
ReadAt
*
time
.
Time
}
type
AnnouncementUserReadStatus
struct
{
UserID
int64
`json:"user_id"`
Email
string
`json:"email"`
Username
string
`json:"username"`
Balance
float64
`json:"balance"`
Eligible
bool
`json:"eligible"`
ReadAt
*
time
.
Time
`json:"read_at,omitempty"`
}
func
(
s
*
AnnouncementService
)
Create
(
ctx
context
.
Context
,
input
*
CreateAnnouncementInput
)
(
*
Announcement
,
error
)
{
if
input
==
nil
{
return
nil
,
fmt
.
Errorf
(
"create announcement: nil input"
)
}
title
:=
strings
.
TrimSpace
(
input
.
Title
)
content
:=
strings
.
TrimSpace
(
input
.
Content
)
if
title
==
""
||
len
(
title
)
>
200
{
return
nil
,
fmt
.
Errorf
(
"create announcement: invalid title"
)
}
if
content
==
""
{
return
nil
,
fmt
.
Errorf
(
"create announcement: content is required"
)
}
status
:=
strings
.
TrimSpace
(
input
.
Status
)
if
status
==
""
{
status
=
AnnouncementStatusDraft
}
if
!
isValidAnnouncementStatus
(
status
)
{
return
nil
,
fmt
.
Errorf
(
"create announcement: invalid status"
)
}
targeting
,
err
:=
domain
.
AnnouncementTargeting
(
input
.
Targeting
)
.
NormalizeAndValidate
()
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
StartsAt
!=
nil
&&
input
.
EndsAt
!=
nil
{
if
!
input
.
StartsAt
.
Before
(
*
input
.
EndsAt
)
{
return
nil
,
fmt
.
Errorf
(
"create announcement: starts_at must be before ends_at"
)
}
}
a
:=
&
Announcement
{
Title
:
title
,
Content
:
content
,
Status
:
status
,
Targeting
:
targeting
,
StartsAt
:
input
.
StartsAt
,
EndsAt
:
input
.
EndsAt
,
}
if
input
.
ActorID
!=
nil
&&
*
input
.
ActorID
>
0
{
a
.
CreatedBy
=
input
.
ActorID
a
.
UpdatedBy
=
input
.
ActorID
}
if
err
:=
s
.
announcementRepo
.
Create
(
ctx
,
a
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create announcement: %w"
,
err
)
}
return
a
,
nil
}
func
(
s
*
AnnouncementService
)
Update
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAnnouncementInput
)
(
*
Announcement
,
error
)
{
if
input
==
nil
{
return
nil
,
fmt
.
Errorf
(
"update announcement: nil input"
)
}
a
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
Title
!=
nil
{
title
:=
strings
.
TrimSpace
(
*
input
.
Title
)
if
title
==
""
||
len
(
title
)
>
200
{
return
nil
,
fmt
.
Errorf
(
"update announcement: invalid title"
)
}
a
.
Title
=
title
}
if
input
.
Content
!=
nil
{
content
:=
strings
.
TrimSpace
(
*
input
.
Content
)
if
content
==
""
{
return
nil
,
fmt
.
Errorf
(
"update announcement: content is required"
)
}
a
.
Content
=
content
}
if
input
.
Status
!=
nil
{
status
:=
strings
.
TrimSpace
(
*
input
.
Status
)
if
!
isValidAnnouncementStatus
(
status
)
{
return
nil
,
fmt
.
Errorf
(
"update announcement: invalid status"
)
}
a
.
Status
=
status
}
if
input
.
Targeting
!=
nil
{
targeting
,
err
:=
domain
.
AnnouncementTargeting
(
*
input
.
Targeting
)
.
NormalizeAndValidate
()
if
err
!=
nil
{
return
nil
,
err
}
a
.
Targeting
=
targeting
}
if
input
.
StartsAt
!=
nil
{
a
.
StartsAt
=
*
input
.
StartsAt
}
if
input
.
EndsAt
!=
nil
{
a
.
EndsAt
=
*
input
.
EndsAt
}
if
a
.
StartsAt
!=
nil
&&
a
.
EndsAt
!=
nil
{
if
!
a
.
StartsAt
.
Before
(
*
a
.
EndsAt
)
{
return
nil
,
fmt
.
Errorf
(
"update announcement: starts_at must be before ends_at"
)
}
}
if
input
.
ActorID
!=
nil
&&
*
input
.
ActorID
>
0
{
a
.
UpdatedBy
=
input
.
ActorID
}
if
err
:=
s
.
announcementRepo
.
Update
(
ctx
,
a
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update announcement: %w"
,
err
)
}
return
a
,
nil
}
func
(
s
*
AnnouncementService
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
if
err
:=
s
.
announcementRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete announcement: %w"
,
err
)
}
return
nil
}
func
(
s
*
AnnouncementService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Announcement
,
error
)
{
return
s
.
announcementRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
AnnouncementService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
AnnouncementListFilters
)
([]
Announcement
,
*
pagination
.
PaginationResult
,
error
)
{
return
s
.
announcementRepo
.
List
(
ctx
,
params
,
filters
)
}
func
(
s
*
AnnouncementService
)
ListForUser
(
ctx
context
.
Context
,
userID
int64
,
unreadOnly
bool
)
([]
UserAnnouncement
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
activeSubs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
activeSubs
))
for
i
:=
range
activeSubs
{
activeGroupIDs
[
activeSubs
[
i
]
.
GroupID
]
=
struct
{}{}
}
now
:=
time
.
Now
()
anns
,
err
:=
s
.
announcementRepo
.
ListActive
(
ctx
,
now
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active announcements: %w"
,
err
)
}
visible
:=
make
([]
Announcement
,
0
,
len
(
anns
))
ids
:=
make
([]
int64
,
0
,
len
(
anns
))
for
i
:=
range
anns
{
a
:=
anns
[
i
]
if
!
a
.
IsActiveAt
(
now
)
{
continue
}
if
!
a
.
Targeting
.
Matches
(
user
.
Balance
,
activeGroupIDs
)
{
continue
}
visible
=
append
(
visible
,
a
)
ids
=
append
(
ids
,
a
.
ID
)
}
if
len
(
visible
)
==
0
{
return
[]
UserAnnouncement
{},
nil
}
readMap
,
err
:=
s
.
readRepo
.
GetReadMapByUser
(
ctx
,
userID
,
ids
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get read map: %w"
,
err
)
}
out
:=
make
([]
UserAnnouncement
,
0
,
len
(
visible
))
for
i
:=
range
visible
{
a
:=
visible
[
i
]
readAt
,
ok
:=
readMap
[
a
.
ID
]
if
unreadOnly
&&
ok
{
continue
}
var
ptr
*
time
.
Time
if
ok
{
t
:=
readAt
ptr
=
&
t
}
out
=
append
(
out
,
UserAnnouncement
{
Announcement
:
a
,
ReadAt
:
ptr
,
})
}
// 未读优先、同状态按创建时间倒序
sort
.
Slice
(
out
,
func
(
i
,
j
int
)
bool
{
ai
,
aj
:=
out
[
i
],
out
[
j
]
if
(
ai
.
ReadAt
==
nil
)
!=
(
aj
.
ReadAt
==
nil
)
{
return
ai
.
ReadAt
==
nil
}
return
ai
.
Announcement
.
ID
>
aj
.
Announcement
.
ID
})
return
out
,
nil
}
func
(
s
*
AnnouncementService
)
MarkRead
(
ctx
context
.
Context
,
userID
,
announcementID
int64
)
error
{
// 安全:仅允许标记当前用户“可见”的公告
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
a
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
announcementID
)
if
err
!=
nil
{
return
err
}
now
:=
time
.
Now
()
if
!
a
.
IsActiveAt
(
now
)
{
return
ErrAnnouncementNotFound
}
activeSubs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
activeSubs
))
for
i
:=
range
activeSubs
{
activeGroupIDs
[
activeSubs
[
i
]
.
GroupID
]
=
struct
{}{}
}
if
!
a
.
Targeting
.
Matches
(
user
.
Balance
,
activeGroupIDs
)
{
return
ErrAnnouncementNotFound
}
if
err
:=
s
.
readRepo
.
MarkRead
(
ctx
,
announcementID
,
userID
,
now
);
err
!=
nil
{
return
fmt
.
Errorf
(
"mark read: %w"
,
err
)
}
return
nil
}
func
(
s
*
AnnouncementService
)
ListUserReadStatus
(
ctx
context
.
Context
,
announcementID
int64
,
params
pagination
.
PaginationParams
,
search
string
,
)
([]
AnnouncementUserReadStatus
,
*
pagination
.
PaginationResult
,
error
)
{
ann
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
announcementID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
filters
:=
UserListFilters
{
Search
:
strings
.
TrimSpace
(
search
),
}
users
,
page
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
filters
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list users: %w"
,
err
)
}
userIDs
:=
make
([]
int64
,
0
,
len
(
users
))
for
i
:=
range
users
{
userIDs
=
append
(
userIDs
,
users
[
i
]
.
ID
)
}
readMap
,
err
:=
s
.
readRepo
.
GetReadMapByUsers
(
ctx
,
announcementID
,
userIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"get read map: %w"
,
err
)
}
out
:=
make
([]
AnnouncementUserReadStatus
,
0
,
len
(
users
))
for
i
:=
range
users
{
u
:=
users
[
i
]
subs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
u
.
ID
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
subs
))
for
j
:=
range
subs
{
activeGroupIDs
[
subs
[
j
]
.
GroupID
]
=
struct
{}{}
}
readAt
,
ok
:=
readMap
[
u
.
ID
]
var
ptr
*
time
.
Time
if
ok
{
t
:=
readAt
ptr
=
&
t
}
out
=
append
(
out
,
AnnouncementUserReadStatus
{
UserID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Balance
:
u
.
Balance
,
Eligible
:
domain
.
AnnouncementTargeting
(
ann
.
Targeting
)
.
Matches
(
u
.
Balance
,
activeGroupIDs
),
ReadAt
:
ptr
,
})
}
return
out
,
page
,
nil
}
func
isValidAnnouncementStatus
(
status
string
)
bool
{
switch
status
{
case
AnnouncementStatusDraft
,
AnnouncementStatusActive
,
AnnouncementStatusArchived
:
return
true
default
:
return
false
}
}
backend/internal/service/announcement_targeting_test.go
0 → 100644
View file @
0170d19f
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestAnnouncementTargeting_Matches_EmptyMatchesAll
(
t
*
testing
.
T
)
{
var
targeting
AnnouncementTargeting
require
.
True
(
t
,
targeting
.
Matches
(
0
,
nil
))
require
.
True
(
t
,
targeting
.
Matches
(
123.45
,
map
[
int64
]
struct
{}{
1
:
{}}))
}
func
TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
nil
},
},
}
_
,
err
:=
targeting
.
NormalizeAndValidate
()
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrAnnouncementInvalidTarget
)
}
func
TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
"balance"
,
Operator
:
"between"
,
Value
:
10
},
},
},
},
}
_
,
err
:=
targeting
.
NormalizeAndValidate
()
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrAnnouncementInvalidTarget
)
}
func
TestAnnouncementTargeting_Matches_AndOrSemantics
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
AnnouncementConditionTypeBalance
,
Operator
:
AnnouncementOperatorGTE
,
Value
:
100
},
{
Type
:
AnnouncementConditionTypeSubscription
,
Operator
:
AnnouncementOperatorIn
,
GroupIDs
:
[]
int64
{
10
}},
},
},
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
AnnouncementConditionTypeBalance
,
Operator
:
AnnouncementOperatorLT
,
Value
:
5
},
},
},
},
}
// 命中第 2 组(balance < 5)
require
.
True
(
t
,
targeting
.
Matches
(
4.99
,
nil
))
require
.
False
(
t
,
targeting
.
Matches
(
5
,
nil
))
// 命中第 1 组(balance >= 100 AND 订阅 in [10])
require
.
False
(
t
,
targeting
.
Matches
(
100
,
map
[
int64
]
struct
{}{}))
require
.
False
(
t
,
targeting
.
Matches
(
99.9
,
map
[
int64
]
struct
{}{
10
:
{}}))
require
.
True
(
t
,
targeting
.
Matches
(
100
,
map
[
int64
]
struct
{}{
10
:
{}}))
}
Prev
1
…
4
5
6
7
8
9
10
11
12
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment