Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0170d19f
"backend/internal/vscode:/vscode.git/clone" did not exist on "f22bc59fe37c9708c5751e6aace6913995d6f251"
Commit
0170d19f
authored
Feb 02, 2026
by
song
Browse files
merge upstream main
parent
7ade9baa
Changes
319
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/usage_log_repo_integration_test.go
View file @
0170d19f
...
...
@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime
:=
base
.
Add
(
48
*
time
.
Hour
)
// Test with user filter
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters user filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with apiKey filter
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with both filters
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters both filters"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime
:=
base
.
Add
(
-
1
*
time
.
Hour
)
endTime
:=
base
.
Add
(
3
*
time
.
Hour
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters hourly"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime
:=
base
.
Add
(
2
*
time
.
Hour
)
// Test with user filter
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
)
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters user filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with apiKey filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with account filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters account filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
}
...
...
backend/internal/repository/user_repo.go
View file @
0170d19f
...
...
@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbuser
"github.com/Wei-Shaw/sub2api/ent/user"
...
...
@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
dbuser
.
Or
(
dbuser
.
EmailContainsFold
(
filters
.
Search
),
dbuser
.
UsernameContainsFold
(
filters
.
Search
),
dbuser
.
NotesContainsFold
(
filters
.
Search
),
),
)
}
...
...
@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst
.
CreatedAt
=
src
.
CreatedAt
dst
.
UpdatedAt
=
src
.
UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func
(
r
*
userRepository
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
update
:=
client
.
User
.
UpdateOneID
(
userID
)
if
encryptedSecret
==
nil
{
update
=
update
.
ClearTotpSecretEncrypted
()
}
else
{
update
=
update
.
SetTotpSecretEncrypted
(
*
encryptedSecret
)
}
_
,
err
:=
update
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
true
)
.
SetTotpEnabledAt
(
time
.
Now
())
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
false
)
.
ClearTotpEnabledAt
()
.
ClearTotpSecretEncrypted
()
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
backend/internal/repository/user_subscription_repo.go
View file @
0170d19f
...
...
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return
userSubscriptionEntitiesToService
(
subs
),
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
q
:=
client
.
UserSubscription
.
Query
()
if
userID
!=
nil
{
...
...
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if
groupID
!=
nil
{
q
=
q
.
Where
(
usersubscription
.
GroupIDEQ
(
*
groupID
))
}
if
status
!=
""
{
// Status filtering with real-time expiration check
now
:=
time
.
Now
()
switch
status
{
case
service
.
SubscriptionStatusActive
:
// Active: status is active AND not yet expired
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtGT
(
now
),
)
case
service
.
SubscriptionStatusExpired
:
// Expired: status is expired OR (status is active but already expired)
q
=
q
.
Where
(
usersubscription
.
Or
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusExpired
),
usersubscription
.
And
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtLTE
(
now
),
),
),
)
case
""
:
// No filter
default
:
// Other status (e.g., revoked)
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
status
))
}
...
...
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return
nil
,
nil
,
err
}
// Apply sorting
q
=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
// Determine sort field
var
field
string
switch
sortBy
{
case
"expires_at"
:
field
=
usersubscription
.
FieldExpiresAt
case
"status"
:
field
=
usersubscription
.
FieldStatus
default
:
field
=
usersubscription
.
FieldCreatedAt
}
// Determine sort order (default: desc)
if
sortOrder
==
"asc"
&&
sortBy
!=
""
{
q
=
q
.
Order
(
dbent
.
Asc
(
field
))
}
else
{
q
=
q
.
Order
(
dbent
.
Desc
(
field
))
}
subs
,
err
:=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
.
Order
(
dbent
.
Desc
(
usersubscription
.
FieldCreatedAt
))
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
All
(
ctx
)
...
...
backend/internal/repository/user_subscription_repo_integration_test.go
View file @
0170d19f
...
...
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group
:=
s
.
mustCreateGroup
(
"g-list"
)
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
...
...
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s
.
mustCreateSubscription
(
user1
.
ID
,
group
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user2
.
ID
,
group
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
user1
.
ID
,
subs
[
0
]
.
UserID
)
...
...
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s
.
mustCreateSubscription
(
user
.
ID
,
g1
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user
.
ID
,
g2
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
g1
.
ID
,
subs
[
0
]
.
GroupID
)
...
...
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
-
24
*
time
.
Hour
))
})
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
service
.
SubscriptionStatusExpired
,
subs
[
0
]
.
Status
)
...
...
backend/internal/repository/wire.go
View file @
0170d19f
...
...
@@ -56,7 +56,10 @@ var ProviderSet = wire.NewSet(
NewProxyRepository
,
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewAnnouncementRepository
,
NewAnnouncementReadRepository
,
NewUsageLogRepository
,
NewUsageCleanupRepository
,
NewDashboardAggregationRepository
,
NewSettingRepository
,
NewOpsRepository
,
...
...
@@ -81,6 +84,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache
,
NewSchedulerOutboxRepository
,
NewProxyLatencyCache
,
NewTotpCache
,
// Encryptors
NewAESEncryptor
,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
...
...
backend/internal/server/api_contract_test.go
View file @
0170d19f
...
...
@@ -11,7 +11,6 @@ import (
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"
"time"
...
...
@@ -52,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
"concurrency": 5,
...
...
@@ -132,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
}
}`
,
},
{
name
:
"GET /api/v1/groups/available"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count)。
deps
.
groupRepo
.
SetActive
([]
service
.
Group
{
{
ID
:
10
,
Name
:
"Group One"
,
Description
:
"desc"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.5
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-*"
:
[]
int64
{
101
,
102
},
},
AccountCount
:
2
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
},
})
deps
.
userSubRepo
.
SetActiveByUserID
(
1
,
nil
)
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/groups/available"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 10,
"name": "Group One",
"description": "desc",
"platform": "anthropic",
"rate_multiplier": 1.5,
"is_exclusive": false,
"status": "active",
"subscription_type": "standard",
"daily_limit_usd": null,
"weekly_limit_usd": null,
"monthly_limit_usd": null,
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`
,
},
{
name
:
"GET /api/v1/subscriptions"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps
.
userSubRepo
.
SetByUserID
(
1
,
[]
service
.
UserSubscription
{
{
ID
:
501
,
UserID
:
1
,
GroupID
:
10
,
StartsAt
:
deps
.
now
,
ExpiresAt
:
time
.
Date
(
2099
,
1
,
2
,
3
,
4
,
5
,
0
,
time
.
UTC
),
// 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status
:
service
.
SubscriptionStatusActive
,
DailyUsageUSD
:
1.23
,
WeeklyUsageUSD
:
2.34
,
MonthlyUsageUSD
:
3.45
,
AssignedBy
:
ptr
(
int64
(
999
)),
AssignedAt
:
deps
.
now
,
Notes
:
"admin-note"
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
},
})
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/subscriptions"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 501,
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
"monthly_window_start": null,
"daily_usage_usd": 1.23,
"weekly_usage_usd": 2.34,
"monthly_usage_usd": 3.45,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`
,
},
{
name
:
"GET /api/v1/redeem/history"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
// 普通用户兑换历史不应包含 notes 等内部字段。
deps
.
redeemRepo
.
SetByUser
(
1
,
[]
service
.
RedeemCode
{
{
ID
:
900
,
Code
:
"CODE-123"
,
Type
:
service
.
RedeemTypeBalance
,
Value
:
1.25
,
Status
:
service
.
StatusUsed
,
UsedBy
:
ptr
(
int64
(
1
)),
UsedAt
:
ptr
(
deps
.
now
),
Notes
:
"internal-note"
,
CreatedAt
:
deps
.
now
,
},
})
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/redeem/history"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": [
{
"id": 900,
"code": "CODE-123",
"type": "balance",
"value": 1.25,
"status": "used",
"used_by": 1,
"used_at": "2025-01-02T03:04:05Z",
"created_at": "2025-01-02T03:04:05Z",
"group_id": null,
"validity_days": 0
}
]
}`
,
},
{
name
:
"GET /api/v1/usage/stats"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
...
...
@@ -191,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
t
.
Helper
()
deps
.
usageRepo
.
SetUserLogs
(
1
,
[]
service
.
UsageLog
{
{
ID
:
1
,
UserID
:
1
,
APIKeyID
:
100
,
AccountID
:
200
,
RequestID
:
"req_123"
,
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
CacheCreationTokens
:
1
,
CacheReadTokens
:
2
,
TotalCost
:
0.5
,
ActualCost
:
0.5
,
RateMultiplier
:
1
,
BillingType
:
service
.
BillingTypeBalance
,
Stream
:
true
,
DurationMs
:
ptr
(
100
),
FirstTokenMs
:
ptr
(
50
),
CreatedAt
:
deps
.
now
,
ID
:
1
,
UserID
:
1
,
APIKeyID
:
100
,
AccountID
:
200
,
AccountRateMultiplier
:
ptr
(
0.5
),
RequestID
:
"req_123"
,
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
CacheCreationTokens
:
1
,
CacheReadTokens
:
2
,
TotalCost
:
0.5
,
ActualCost
:
0.5
,
RateMultiplier
:
1
,
BillingType
:
service
.
BillingTypeBalance
,
Stream
:
true
,
DurationMs
:
ptr
(
100
),
FirstTokenMs
:
ptr
(
50
),
CreatedAt
:
deps
.
now
,
},
})
},
...
...
@@ -239,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
"output_cost": 0,
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
...
...
@@ -267,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps
.
settingRepo
.
SetAll
(
map
[
string
]
string
{
service
.
SettingKeyRegistrationEnabled
:
"true"
,
service
.
SettingKeyEmailVerifyEnabled
:
"false"
,
service
.
SettingKeyPromoCodeEnabled
:
"true"
,
service
.
SettingKeySMTPHost
:
"smtp.example.com"
,
service
.
SettingKeySMTPPort
:
"587"
,
...
...
@@ -305,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
...
...
@@ -338,45 +488,10 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"home_content": ""
}
}`
,
},
{
name
:
"POST /api/v1/admin/accounts/lookup"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
deps
.
accountRepo
.
lookupAccounts
=
[]
service
.
Account
{
{
ID
:
101
,
Name
:
"Alice Account"
,
Platform
:
"antigravity"
,
Credentials
:
map
[
string
]
any
{
"email"
:
"alice@example.com"
,
},
},
}
},
method
:
http
.
MethodPost
,
path
:
"/api/v1/admin/accounts/lookup"
,
body
:
`{"platform":"antigravity","emails":["Alice@Example.com","bob@example.com"]}`
,
headers
:
map
[
string
]
string
{
"Content-Type"
:
"application/json"
,
},
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": {
"matched": [
{
"email": "alice@example.com",
"account_id": 101,
"platform": "antigravity",
"name": "Alice Account"
}
],
"missing": ["bob@example.com"]
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": ""
}
}`
,
},
...
...
@@ -424,9 +539,11 @@ type contractDeps struct {
now
time
.
Time
router
http
.
Handler
apiKeyRepo
*
stubApiKeyRepo
groupRepo
*
stubGroupRepo
userSubRepo
*
stubUserSubscriptionRepo
usageRepo
*
stubUsageLogRepo
settingRepo
*
stubSettingRepo
account
Repo
*
stub
Account
Repo
redeem
Repo
*
stub
RedeemCode
Repo
}
func
newContractDeps
(
t
*
testing
.
T
)
*
contractDeps
{
...
...
@@ -454,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyRepo
:=
newStubApiKeyRepo
(
now
)
apiKeyCache
:=
stubApiKeyCache
{}
groupRepo
:=
stubGroupRepo
{}
userSubRepo
:=
stubUserSubscriptionRepo
{}
groupRepo
:=
&
stubGroupRepo
{}
userSubRepo
:=
&
stubUserSubscriptionRepo
{}
accountRepo
:=
stubAccountRepo
{}
proxyRepo
:=
stubProxyRepo
{}
redeemRepo
:=
stubRedeemCodeRepo
{}
redeemRepo
:=
&
stubRedeemCodeRepo
{}
cfg
:=
&
config
.
Config
{
Default
:
config
.
DefaultConfig
{
...
...
@@ -473,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
usageRepo
:=
newStubUsageLogRepo
()
usageService
:=
service
.
NewUsageService
(
usageRepo
,
userRepo
,
nil
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepo
,
userSubRepo
,
nil
)
subscriptionHandler
:=
handler
.
NewSubscriptionHandler
(
subscriptionService
)
redeemService
:=
service
.
NewRedeemService
(
redeemRepo
,
userRepo
,
subscriptionService
,
nil
,
nil
,
nil
,
nil
)
redeemHandler
:=
handler
.
NewRedeemHandler
(
redeemService
)
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
@@ -512,25 +635,35 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Keys
.
Use
(
jwtAuth
)
v1Keys
.
GET
(
"/keys"
,
apiKeyHandler
.
List
)
v1Keys
.
POST
(
"/keys"
,
apiKeyHandler
.
Create
)
v1Keys
.
GET
(
"/groups/available"
,
apiKeyHandler
.
GetAvailableGroups
)
v1Usage
:=
v1
.
Group
(
""
)
v1Usage
.
Use
(
jwtAuth
)
v1Usage
.
GET
(
"/usage"
,
usageHandler
.
List
)
v1Usage
.
GET
(
"/usage/stats"
,
usageHandler
.
Stats
)
v1Subs
:=
v1
.
Group
(
""
)
v1Subs
.
Use
(
jwtAuth
)
v1Subs
.
GET
(
"/subscriptions"
,
subscriptionHandler
.
List
)
v1Redeem
:=
v1
.
Group
(
""
)
v1Redeem
.
Use
(
jwtAuth
)
v1Redeem
.
GET
(
"/redeem/history"
,
redeemHandler
.
GetHistory
)
v1Admin
:=
v1
.
Group
(
"/admin"
)
v1Admin
.
Use
(
adminAuth
)
v1Admin
.
GET
(
"/settings"
,
adminSettingHandler
.
GetSettings
)
v1Admin
.
POST
(
"/accounts/bulk-update"
,
adminAccountHandler
.
BulkUpdate
)
v1Admin
.
POST
(
"/accounts/lookup"
,
adminAccountHandler
.
Lookup
)
return
&
contractDeps
{
now
:
now
,
router
:
r
,
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userSubRepo
:
userSubRepo
,
usageRepo
:
usageRepo
,
settingRepo
:
settingRepo
,
accountRepo
:
&
account
Repo
,
redeemRepo
:
redeem
Repo
,
}
}
...
...
@@ -626,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
type
stubApiKeyCache
struct
{}
func
(
stubApiKeyCache
)
GetCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
...
...
@@ -660,7 +805,21 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return
nil
}
type
stubGroupRepo
struct
{}
func
(
stubApiKeyCache
)
PublishAuthCacheInvalidation
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
return
nil
}
func
(
stubApiKeyCache
)
SubscribeAuthCacheInvalidation
(
ctx
context
.
Context
,
handler
func
(
cacheKey
string
))
error
{
return
nil
}
type
stubGroupRepo
struct
{
active
[]
service
.
Group
}
func
(
r
*
stubGroupRepo
)
SetActive
(
groups
[]
service
.
Group
)
{
r
.
active
=
append
([]
service
.
Group
(
nil
),
groups
...
)
}
func
(
stubGroupRepo
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -694,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
append
([]
service
.
Group
(
nil
),
r
.
active
...
),
nil
}
func
(
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
out
:=
make
([]
service
.
Group
,
0
,
len
(
r
.
active
))
for
i
:=
range
r
.
active
{
g
:=
r
.
active
[
i
]
if
g
.
Platform
==
platform
{
out
=
append
(
out
,
g
)
}
}
return
out
,
nil
}
func
(
stubGroupRepo
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
...
...
@@ -715,8 +881,7 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
}
type
stubAccountRepo
struct
{
bulkUpdateIDs
[]
int64
lookupAccounts
[]
service
.
Account
bulkUpdateIDs
[]
int64
}
func
(
s
*
stubAccountRepo
)
Create
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
...
...
@@ -767,36 +932,6 @@ func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) (
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListByPlatformAndCredentialEmails
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
s
.
lookupAccounts
)
==
0
{
return
nil
,
nil
}
emailSet
:=
make
(
map
[
string
]
struct
{},
len
(
emails
))
for
_
,
email
:=
range
emails
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
if
normalized
==
""
{
continue
}
emailSet
[
normalized
]
=
struct
{}{}
}
var
matches
[]
service
.
Account
for
i
:=
range
s
.
lookupAccounts
{
account
:=
&
s
.
lookupAccounts
[
i
]
if
account
.
Platform
!=
platform
{
continue
}
accountEmail
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
account
.
GetCredential
(
"email"
)))
if
accountEmail
==
""
{
continue
}
if
_
,
ok
:=
emailSet
[
accountEmail
];
!
ok
{
continue
}
matches
=
append
(
matches
,
*
account
)
}
return
matches
,
nil
}
func
(
s
*
stubAccountRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -948,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubRedeemCodeRepo
struct
{}
type
stubRedeemCodeRepo
struct
{
byUser
map
[
int64
][]
service
.
RedeemCode
}
func
(
r
*
stubRedeemCodeRepo
)
SetByUser
(
userID
int64
,
codes
[]
service
.
RedeemCode
)
{
if
r
.
byUser
==
nil
{
r
.
byUser
=
make
(
map
[
int64
][]
service
.
RedeemCode
)
}
r
.
byUser
[
userID
]
=
append
([]
service
.
RedeemCode
(
nil
),
codes
...
)
}
func
(
stubRedeemCodeRepo
)
Create
(
ctx
context
.
Context
,
code
*
service
.
RedeemCode
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -986,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
service
.
RedeemCode
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubRedeemCodeRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
service
.
RedeemCode
,
error
)
{
if
r
.
byUser
==
nil
{
return
nil
,
nil
}
codes
:=
r
.
byUser
[
userID
]
if
limit
>
0
&&
len
(
codes
)
>
limit
{
codes
=
codes
[
:
limit
]
}
return
append
([]
service
.
RedeemCode
(
nil
),
codes
...
),
nil
}
type
stubUserSubscriptionRepo
struct
{
byUser
map
[
int64
][]
service
.
UserSubscription
activeByUser
map
[
int64
][]
service
.
UserSubscription
}
func
(
r
*
stubUserSubscriptionRepo
)
SetByUserID
(
userID
int64
,
subs
[]
service
.
UserSubscription
)
{
if
r
.
byUser
==
nil
{
r
.
byUser
=
make
(
map
[
int64
][]
service
.
UserSubscription
)
}
r
.
byUser
[
userID
]
=
append
([]
service
.
UserSubscription
(
nil
),
subs
...
)
}
type
stubUserSubscriptionRepo
struct
{}
func
(
r
*
stubUserSubscriptionRepo
)
SetActiveByUserID
(
userID
int64
,
subs
[]
service
.
UserSubscription
)
{
if
r
.
activeByUser
==
nil
{
r
.
activeByUser
=
make
(
map
[
int64
][]
service
.
UserSubscription
)
}
r
.
activeByUser
[
userID
]
=
append
([]
service
.
UserSubscription
(
nil
),
subs
...
)
}
func
(
stubUserSubscriptionRepo
)
Create
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
return
errors
.
New
(
"not implemented"
)
...
...
@@ -1010,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
func
(
stubUserSubscriptionRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubUserSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
if
r
.
byUser
==
nil
{
return
nil
,
nil
}
return
append
([]
service
.
UserSubscription
(
nil
),
r
.
byUser
[
userID
]
...
),
nil
}
func
(
stubUserSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubUserSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
if
r
.
activeByUser
==
nil
{
return
nil
,
nil
}
return
append
([]
service
.
UserSubscription
(
nil
),
r
.
activeByUser
[
userID
]
...
),
nil
}
func
(
stubUserSubscriptionRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
...
...
@@ -1319,11 +1493,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
0170d19f
...
...
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/routes/admin.go
View file @
0170d19f
...
...
@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
// 账号管理
registerAccountRoutes
(
admin
,
h
)
// 公告管理
registerAnnouncementRoutes
(
admin
,
h
)
// OpenAI OAuth
registerOpenAIOAuthRoutes
(
admin
,
h
)
...
...
@@ -197,7 +200,6 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts
:=
admin
.
Group
(
"/accounts"
)
{
accounts
.
GET
(
""
,
h
.
Admin
.
Account
.
List
)
accounts
.
POST
(
"/lookup"
,
h
.
Admin
.
Account
.
Lookup
)
accounts
.
GET
(
"/:id"
,
h
.
Admin
.
Account
.
GetByID
)
accounts
.
POST
(
""
,
h
.
Admin
.
Account
.
Create
)
accounts
.
POST
(
"/sync/crs"
,
h
.
Admin
.
Account
.
SyncFromCRS
)
...
...
@@ -230,6 +232,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerAnnouncementRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
announcements
:=
admin
.
Group
(
"/announcements"
)
{
announcements
.
GET
(
""
,
h
.
Admin
.
Announcement
.
List
)
announcements
.
POST
(
""
,
h
.
Admin
.
Announcement
.
Create
)
announcements
.
GET
(
"/:id"
,
h
.
Admin
.
Announcement
.
GetByID
)
announcements
.
PUT
(
"/:id"
,
h
.
Admin
.
Announcement
.
Update
)
announcements
.
DELETE
(
"/:id"
,
h
.
Admin
.
Announcement
.
Delete
)
announcements
.
GET
(
"/:id/read-status"
,
h
.
Admin
.
Announcement
.
ListReadStatus
)
}
}
func
registerOpenAIOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
openai
:=
admin
.
Group
(
"/openai"
)
{
...
...
@@ -355,6 +369,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage
.
GET
(
"/stats"
,
h
.
Admin
.
Usage
.
Stats
)
usage
.
GET
(
"/search-users"
,
h
.
Admin
.
Usage
.
SearchUsers
)
usage
.
GET
(
"/search-api-keys"
,
h
.
Admin
.
Usage
.
SearchAPIKeys
)
usage
.
GET
(
"/cleanup-tasks"
,
h
.
Admin
.
Usage
.
ListCleanupTasks
)
usage
.
POST
(
"/cleanup-tasks"
,
h
.
Admin
.
Usage
.
CreateCleanupTask
)
usage
.
POST
(
"/cleanup-tasks/:id/cancel"
,
h
.
Admin
.
Usage
.
CancelCleanupTask
)
}
}
...
...
backend/internal/server/routes/auth.go
View file @
0170d19f
...
...
@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/validate-promo-code"
,
rateLimiter
.
LimitWithOptions
(
"validate-promo"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ValidatePromoCode
)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth
.
POST
(
"/forgot-password"
,
rateLimiter
.
LimitWithOptions
(
"forgot-password"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ForgotPassword
)
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/reset-password"
,
rateLimiter
.
LimitWithOptions
(
"reset-password"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ResetPassword
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
...
...
backend/internal/server/routes/user.go
View file @
0170d19f
...
...
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user
.
GET
(
"/profile"
,
h
.
User
.
GetProfile
)
user
.
PUT
(
"/password"
,
h
.
User
.
ChangePassword
)
user
.
PUT
(
""
,
h
.
User
.
UpdateProfile
)
// TOTP 双因素认证
totp
:=
user
.
Group
(
"/totp"
)
{
totp
.
GET
(
"/status"
,
h
.
Totp
.
GetStatus
)
totp
.
GET
(
"/verification-method"
,
h
.
Totp
.
GetVerificationMethod
)
totp
.
POST
(
"/send-code"
,
h
.
Totp
.
SendVerifyCode
)
totp
.
POST
(
"/setup"
,
h
.
Totp
.
InitiateSetup
)
totp
.
POST
(
"/enable"
,
h
.
Totp
.
Enable
)
totp
.
POST
(
"/disable"
,
h
.
Totp
.
Disable
)
}
}
// API Key管理
...
...
@@ -53,6 +64,13 @@ func RegisterUserRoutes(
usage
.
POST
(
"/dashboard/api-keys-usage"
,
h
.
Usage
.
DashboardAPIKeysUsage
)
}
// 公告(用户可见)
announcements
:=
authenticated
.
Group
(
"/announcements"
)
{
announcements
.
GET
(
""
,
h
.
Announcement
.
List
)
announcements
.
POST
(
"/:id/read"
,
h
.
Announcement
.
MarkRead
)
}
// 卡密兑换
redeem
:=
authenticated
.
Group
(
"/redeem"
)
{
...
...
backend/internal/service/account.go
View file @
0170d19f
...
...
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return
nil
}
// GetCredentialAsInt64 解析凭证中的 int64 字段
// 用于读取 _token_version 等内部字段
func
(
a
*
Account
)
GetCredentialAsInt64
(
key
string
)
int64
{
if
a
==
nil
||
a
.
Credentials
==
nil
{
return
0
}
val
,
ok
:=
a
.
Credentials
[
key
]
if
!
ok
||
val
==
nil
{
return
0
}
switch
v
:=
val
.
(
type
)
{
case
int64
:
return
v
case
float64
:
return
int64
(
v
)
case
int
:
return
int64
(
v
)
case
json
.
Number
:
if
i
,
err
:=
v
.
Int64
();
err
==
nil
{
return
i
}
case
string
:
if
i
,
err
:=
strconv
.
ParseInt
(
strings
.
TrimSpace
(
v
),
10
,
64
);
err
==
nil
{
return
i
}
}
return
0
}
func
(
a
*
Account
)
IsTempUnschedulableEnabled
()
bool
{
if
a
.
Credentials
==
nil
{
return
false
...
...
@@ -576,6 +605,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeSetupToken
)
}
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
func
(
a
*
Account
)
IsTLSFingerprintEnabled
()
bool
{
// 仅支持 Anthropic OAuth/SetupToken 账号
if
!
a
.
IsAnthropicOAuthOrSetupToken
()
{
return
false
}
if
a
.
Extra
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Extra
[
"enable_tls_fingerprint"
];
ok
{
if
enabled
,
ok
:=
v
.
(
bool
);
ok
{
return
enabled
}
}
return
false
}
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
// 使上游认为请求来自同一个会话
func
(
a
*
Account
)
IsSessionIDMaskingEnabled
()
bool
{
if
!
a
.
IsAnthropicOAuthOrSetupToken
()
{
return
false
}
if
a
.
Extra
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Extra
[
"session_id_masking_enabled"
];
ok
{
if
enabled
,
ok
:=
v
.
(
bool
);
ok
{
return
enabled
}
}
return
false
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
...
...
@@ -652,6 +719,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
return
WindowCostNotSchedulable
}
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
// 逻辑:
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
func
(
a
*
Account
)
GetCurrentWindowStartTime
()
time
.
Time
{
now
:=
time
.
Now
()
// 窗口未过期,使用记录的窗口开始时间
if
a
.
SessionWindowStart
!=
nil
&&
a
.
SessionWindowEnd
!=
nil
&&
now
.
Before
(
*
a
.
SessionWindowEnd
)
{
return
*
a
.
SessionWindowStart
}
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
return
time
.
Date
(
now
.
Year
(),
now
.
Month
(),
now
.
Day
(),
now
.
Hour
(),
0
,
0
,
0
,
now
.
Location
())
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func
parseExtraFloat64
(
value
any
)
float64
{
switch
v
:=
value
.
(
type
)
{
...
...
backend/internal/service/account_service.go
View file @
0170d19f
...
...
@@ -33,7 +33,6 @@ type AccountRepository interface {
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
ListActive
(
ctx
context
.
Context
)
([]
Account
,
error
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
ListByPlatformAndCredentialEmails
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
...
...
backend/internal/service/account_service_delete_test.go
View file @
0170d19f
...
...
@@ -87,10 +87,6 @@ func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) (
panic
(
"unexpected ListByPlatform call"
)
}
func
(
s
*
accountRepoStub
)
ListByPlatformAndCredentialEmails
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
{
panic
(
"unexpected ListByPlatformAndCredentialEmails call"
)
}
func
(
s
*
accountRepoStub
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected UpdateLastUsed call"
)
}
...
...
backend/internal/service/account_test_service.go
View file @
0170d19f
...
...
@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
backend/internal/service/account_usage_service.go
View file @
0170d19f
...
...
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
...
...
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
}
`json:"seven_day_sonnet"`
}
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
type
ClaudeUsageFetchOptions
struct
{
AccessToken
string
// OAuth access token
ProxyURL
string
// 代理 URL(可选)
AccountID
int64
// 账号 ID(用于 TLS 指纹选择)
EnableTLSFingerprint
bool
// 是否启用 TLS 指纹伪装
Fingerprint
*
Fingerprint
// 缓存的指纹信息(User-Agent 等)
}
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type
ClaudeUsageFetcher
interface
{
FetchUsage
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
*
ClaudeUsageResponse
,
error
)
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
FetchUsageWithOptions
(
ctx
context
.
Context
,
opts
*
ClaudeUsageFetchOptions
)
(
*
ClaudeUsageResponse
,
error
)
}
// AccountUsageService 账号使用量查询服务
...
...
@@ -170,6 +181,7 @@ type AccountUsageService struct {
geminiQuotaService
*
GeminiQuotaService
antigravityQuotaFetcher
*
AntigravityQuotaFetcher
cache
*
UsageCache
identityCache
IdentityCache
}
// NewAccountUsageService 创建AccountUsageService实例
...
...
@@ -180,6 +192,7 @@ func NewAccountUsageService(
geminiQuotaService
*
GeminiQuotaService
,
antigravityQuotaFetcher
*
AntigravityQuotaFetcher
,
cache
*
UsageCache
,
identityCache
IdentityCache
,
)
*
AccountUsageService
{
return
&
AccountUsageService
{
accountRepo
:
accountRepo
,
...
...
@@ -188,6 +201,7 @@ func NewAccountUsageService(
geminiQuotaService
:
geminiQuotaService
,
antigravityQuotaFetcher
:
antigravityQuotaFetcher
,
cache
:
cache
,
identityCache
:
identityCache
,
}
}
...
...
@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart
:=
geminiDailyWindowStart
(
now
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini usage stats failed: %w"
,
err
)
}
...
...
@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart
:=
now
.
Truncate
(
time
.
Minute
)
minuteResetAt
:=
minuteStart
.
Add
(
time
.
Minute
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini minute usage stats failed: %w"
,
err
)
}
...
...
@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 如果没有缓存,从数据库查询
if
windowStats
==
nil
{
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime
:=
account
.
GetCurrentWindowStartTime
()
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
...
...
@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
}
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
func
(
s
*
AccountUsageService
)
fetchOAuthUsageRaw
(
ctx
context
.
Context
,
account
*
Account
)
(
*
ClaudeUsageResponse
,
error
)
{
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
...
...
@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
proxyURL
=
account
.
Proxy
.
URL
()
}
return
s
.
usageFetcher
.
FetchUsage
(
ctx
,
accessToken
,
proxyURL
)
// 构建完整的选项
opts
:=
&
ClaudeUsageFetchOptions
{
AccessToken
:
accessToken
,
ProxyURL
:
proxyURL
,
AccountID
:
account
.
ID
,
EnableTLSFingerprint
:
account
.
IsTLSFingerprintEnabled
(),
}
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
if
s
.
identityCache
!=
nil
{
if
fp
,
err
:=
s
.
identityCache
.
GetFingerprint
(
ctx
,
account
.
ID
);
err
==
nil
&&
fp
!=
nil
{
opts
.
Fingerprint
=
fp
}
}
return
s
.
usageFetcher
.
FetchUsageWithOptions
(
ctx
,
opts
)
}
// parseTime 尝试多种格式解析时间
...
...
backend/internal/service/admin_service.go
View file @
0170d19f
...
...
@@ -40,7 +40,6 @@ type AdminService interface {
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
Account
,
error
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAccountInput
)
(
*
Account
,
error
)
DeleteAccount
(
ctx
context
.
Context
,
id
int64
)
error
LookupAccountsByCredentialEmail
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
RefreshAccountCredentials
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
ClearAccountError
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
SetAccountError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
...
...
@@ -866,13 +865,6 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
LookupAccountsByCredentialEmail
(
ctx
context
.
Context
,
platform
string
,
emails
[]
string
)
([]
Account
,
error
)
{
if
platform
==
""
||
len
(
emails
)
==
0
{
return
[]
Account
{},
nil
}
return
s
.
accountRepo
.
ListByPlatformAndCredentialEmails
(
ctx
,
platform
,
emails
)
}
func
(
s
*
adminServiceImpl
)
GetAccountsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
{
if
len
(
ids
)
==
0
{
return
[]
*
Account
{},
nil
...
...
backend/internal/service/admin_service_delete_test.go
View file @
0170d19f
...
...
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic
(
"unexpected RemoveGroupFromAllowedGroups call"
)
}
func
(
s
*
userRepoStub
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
panic
(
"unexpected UpdateTotpSecret call"
)
}
func
(
s
*
userRepoStub
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected EnableTotp call"
)
}
func
(
s
*
userRepoStub
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected DisableTotp call"
)
}
type
groupRepoStub
struct
{
affectedUserIDs
[]
int64
deleteErr
error
...
...
backend/internal/service/announcement.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const
(
AnnouncementStatusDraft
=
domain
.
AnnouncementStatusDraft
AnnouncementStatusActive
=
domain
.
AnnouncementStatusActive
AnnouncementStatusArchived
=
domain
.
AnnouncementStatusArchived
)
const
(
AnnouncementConditionTypeSubscription
=
domain
.
AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance
=
domain
.
AnnouncementConditionTypeBalance
)
const
(
AnnouncementOperatorIn
=
domain
.
AnnouncementOperatorIn
AnnouncementOperatorGT
=
domain
.
AnnouncementOperatorGT
AnnouncementOperatorGTE
=
domain
.
AnnouncementOperatorGTE
AnnouncementOperatorLT
=
domain
.
AnnouncementOperatorLT
AnnouncementOperatorLTE
=
domain
.
AnnouncementOperatorLTE
AnnouncementOperatorEQ
=
domain
.
AnnouncementOperatorEQ
)
var
(
ErrAnnouncementNotFound
=
domain
.
ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget
=
domain
.
ErrAnnouncementInvalidTarget
)
type
AnnouncementTargeting
=
domain
.
AnnouncementTargeting
type
AnnouncementConditionGroup
=
domain
.
AnnouncementConditionGroup
type
AnnouncementCondition
=
domain
.
AnnouncementCondition
type
Announcement
=
domain
.
Announcement
type
AnnouncementListFilters
struct
{
Status
string
Search
string
}
type
AnnouncementRepository
interface
{
Create
(
ctx
context
.
Context
,
a
*
Announcement
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Announcement
,
error
)
Update
(
ctx
context
.
Context
,
a
*
Announcement
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
AnnouncementListFilters
)
([]
Announcement
,
*
pagination
.
PaginationResult
,
error
)
ListActive
(
ctx
context
.
Context
,
now
time
.
Time
)
([]
Announcement
,
error
)
}
type
AnnouncementReadRepository
interface
{
MarkRead
(
ctx
context
.
Context
,
announcementID
,
userID
int64
,
readAt
time
.
Time
)
error
GetReadMapByUser
(
ctx
context
.
Context
,
userID
int64
,
announcementIDs
[]
int64
)
(
map
[
int64
]
time
.
Time
,
error
)
GetReadMapByUsers
(
ctx
context
.
Context
,
announcementID
int64
,
userIDs
[]
int64
)
(
map
[
int64
]
time
.
Time
,
error
)
CountByAnnouncementID
(
ctx
context
.
Context
,
announcementID
int64
)
(
int64
,
error
)
}
backend/internal/service/announcement_service.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type
AnnouncementService
struct
{
announcementRepo
AnnouncementRepository
readRepo
AnnouncementReadRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
}
func
NewAnnouncementService
(
announcementRepo
AnnouncementRepository
,
readRepo
AnnouncementReadRepository
,
userRepo
UserRepository
,
userSubRepo
UserSubscriptionRepository
,
)
*
AnnouncementService
{
return
&
AnnouncementService
{
announcementRepo
:
announcementRepo
,
readRepo
:
readRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
}
}
type
CreateAnnouncementInput
struct
{
Title
string
Content
string
Status
string
Targeting
AnnouncementTargeting
StartsAt
*
time
.
Time
EndsAt
*
time
.
Time
ActorID
*
int64
// 管理员用户ID
}
type
UpdateAnnouncementInput
struct
{
Title
*
string
Content
*
string
Status
*
string
Targeting
*
AnnouncementTargeting
StartsAt
**
time
.
Time
EndsAt
**
time
.
Time
ActorID
*
int64
// 管理员用户ID
}
type
UserAnnouncement
struct
{
Announcement
Announcement
ReadAt
*
time
.
Time
}
type
AnnouncementUserReadStatus
struct
{
UserID
int64
`json:"user_id"`
Email
string
`json:"email"`
Username
string
`json:"username"`
Balance
float64
`json:"balance"`
Eligible
bool
`json:"eligible"`
ReadAt
*
time
.
Time
`json:"read_at,omitempty"`
}
func
(
s
*
AnnouncementService
)
Create
(
ctx
context
.
Context
,
input
*
CreateAnnouncementInput
)
(
*
Announcement
,
error
)
{
if
input
==
nil
{
return
nil
,
fmt
.
Errorf
(
"create announcement: nil input"
)
}
title
:=
strings
.
TrimSpace
(
input
.
Title
)
content
:=
strings
.
TrimSpace
(
input
.
Content
)
if
title
==
""
||
len
(
title
)
>
200
{
return
nil
,
fmt
.
Errorf
(
"create announcement: invalid title"
)
}
if
content
==
""
{
return
nil
,
fmt
.
Errorf
(
"create announcement: content is required"
)
}
status
:=
strings
.
TrimSpace
(
input
.
Status
)
if
status
==
""
{
status
=
AnnouncementStatusDraft
}
if
!
isValidAnnouncementStatus
(
status
)
{
return
nil
,
fmt
.
Errorf
(
"create announcement: invalid status"
)
}
targeting
,
err
:=
domain
.
AnnouncementTargeting
(
input
.
Targeting
)
.
NormalizeAndValidate
()
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
StartsAt
!=
nil
&&
input
.
EndsAt
!=
nil
{
if
!
input
.
StartsAt
.
Before
(
*
input
.
EndsAt
)
{
return
nil
,
fmt
.
Errorf
(
"create announcement: starts_at must be before ends_at"
)
}
}
a
:=
&
Announcement
{
Title
:
title
,
Content
:
content
,
Status
:
status
,
Targeting
:
targeting
,
StartsAt
:
input
.
StartsAt
,
EndsAt
:
input
.
EndsAt
,
}
if
input
.
ActorID
!=
nil
&&
*
input
.
ActorID
>
0
{
a
.
CreatedBy
=
input
.
ActorID
a
.
UpdatedBy
=
input
.
ActorID
}
if
err
:=
s
.
announcementRepo
.
Create
(
ctx
,
a
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create announcement: %w"
,
err
)
}
return
a
,
nil
}
func
(
s
*
AnnouncementService
)
Update
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAnnouncementInput
)
(
*
Announcement
,
error
)
{
if
input
==
nil
{
return
nil
,
fmt
.
Errorf
(
"update announcement: nil input"
)
}
a
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
Title
!=
nil
{
title
:=
strings
.
TrimSpace
(
*
input
.
Title
)
if
title
==
""
||
len
(
title
)
>
200
{
return
nil
,
fmt
.
Errorf
(
"update announcement: invalid title"
)
}
a
.
Title
=
title
}
if
input
.
Content
!=
nil
{
content
:=
strings
.
TrimSpace
(
*
input
.
Content
)
if
content
==
""
{
return
nil
,
fmt
.
Errorf
(
"update announcement: content is required"
)
}
a
.
Content
=
content
}
if
input
.
Status
!=
nil
{
status
:=
strings
.
TrimSpace
(
*
input
.
Status
)
if
!
isValidAnnouncementStatus
(
status
)
{
return
nil
,
fmt
.
Errorf
(
"update announcement: invalid status"
)
}
a
.
Status
=
status
}
if
input
.
Targeting
!=
nil
{
targeting
,
err
:=
domain
.
AnnouncementTargeting
(
*
input
.
Targeting
)
.
NormalizeAndValidate
()
if
err
!=
nil
{
return
nil
,
err
}
a
.
Targeting
=
targeting
}
if
input
.
StartsAt
!=
nil
{
a
.
StartsAt
=
*
input
.
StartsAt
}
if
input
.
EndsAt
!=
nil
{
a
.
EndsAt
=
*
input
.
EndsAt
}
if
a
.
StartsAt
!=
nil
&&
a
.
EndsAt
!=
nil
{
if
!
a
.
StartsAt
.
Before
(
*
a
.
EndsAt
)
{
return
nil
,
fmt
.
Errorf
(
"update announcement: starts_at must be before ends_at"
)
}
}
if
input
.
ActorID
!=
nil
&&
*
input
.
ActorID
>
0
{
a
.
UpdatedBy
=
input
.
ActorID
}
if
err
:=
s
.
announcementRepo
.
Update
(
ctx
,
a
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update announcement: %w"
,
err
)
}
return
a
,
nil
}
func
(
s
*
AnnouncementService
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
if
err
:=
s
.
announcementRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete announcement: %w"
,
err
)
}
return
nil
}
func
(
s
*
AnnouncementService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Announcement
,
error
)
{
return
s
.
announcementRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
AnnouncementService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
AnnouncementListFilters
)
([]
Announcement
,
*
pagination
.
PaginationResult
,
error
)
{
return
s
.
announcementRepo
.
List
(
ctx
,
params
,
filters
)
}
func
(
s
*
AnnouncementService
)
ListForUser
(
ctx
context
.
Context
,
userID
int64
,
unreadOnly
bool
)
([]
UserAnnouncement
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
activeSubs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
activeSubs
))
for
i
:=
range
activeSubs
{
activeGroupIDs
[
activeSubs
[
i
]
.
GroupID
]
=
struct
{}{}
}
now
:=
time
.
Now
()
anns
,
err
:=
s
.
announcementRepo
.
ListActive
(
ctx
,
now
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active announcements: %w"
,
err
)
}
visible
:=
make
([]
Announcement
,
0
,
len
(
anns
))
ids
:=
make
([]
int64
,
0
,
len
(
anns
))
for
i
:=
range
anns
{
a
:=
anns
[
i
]
if
!
a
.
IsActiveAt
(
now
)
{
continue
}
if
!
a
.
Targeting
.
Matches
(
user
.
Balance
,
activeGroupIDs
)
{
continue
}
visible
=
append
(
visible
,
a
)
ids
=
append
(
ids
,
a
.
ID
)
}
if
len
(
visible
)
==
0
{
return
[]
UserAnnouncement
{},
nil
}
readMap
,
err
:=
s
.
readRepo
.
GetReadMapByUser
(
ctx
,
userID
,
ids
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get read map: %w"
,
err
)
}
out
:=
make
([]
UserAnnouncement
,
0
,
len
(
visible
))
for
i
:=
range
visible
{
a
:=
visible
[
i
]
readAt
,
ok
:=
readMap
[
a
.
ID
]
if
unreadOnly
&&
ok
{
continue
}
var
ptr
*
time
.
Time
if
ok
{
t
:=
readAt
ptr
=
&
t
}
out
=
append
(
out
,
UserAnnouncement
{
Announcement
:
a
,
ReadAt
:
ptr
,
})
}
// 未读优先、同状态按创建时间倒序
sort
.
Slice
(
out
,
func
(
i
,
j
int
)
bool
{
ai
,
aj
:=
out
[
i
],
out
[
j
]
if
(
ai
.
ReadAt
==
nil
)
!=
(
aj
.
ReadAt
==
nil
)
{
return
ai
.
ReadAt
==
nil
}
return
ai
.
Announcement
.
ID
>
aj
.
Announcement
.
ID
})
return
out
,
nil
}
func
(
s
*
AnnouncementService
)
MarkRead
(
ctx
context
.
Context
,
userID
,
announcementID
int64
)
error
{
// 安全:仅允许标记当前用户“可见”的公告
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
a
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
announcementID
)
if
err
!=
nil
{
return
err
}
now
:=
time
.
Now
()
if
!
a
.
IsActiveAt
(
now
)
{
return
ErrAnnouncementNotFound
}
activeSubs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
activeSubs
))
for
i
:=
range
activeSubs
{
activeGroupIDs
[
activeSubs
[
i
]
.
GroupID
]
=
struct
{}{}
}
if
!
a
.
Targeting
.
Matches
(
user
.
Balance
,
activeGroupIDs
)
{
return
ErrAnnouncementNotFound
}
if
err
:=
s
.
readRepo
.
MarkRead
(
ctx
,
announcementID
,
userID
,
now
);
err
!=
nil
{
return
fmt
.
Errorf
(
"mark read: %w"
,
err
)
}
return
nil
}
func
(
s
*
AnnouncementService
)
ListUserReadStatus
(
ctx
context
.
Context
,
announcementID
int64
,
params
pagination
.
PaginationParams
,
search
string
,
)
([]
AnnouncementUserReadStatus
,
*
pagination
.
PaginationResult
,
error
)
{
ann
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
announcementID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
filters
:=
UserListFilters
{
Search
:
strings
.
TrimSpace
(
search
),
}
users
,
page
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
filters
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list users: %w"
,
err
)
}
userIDs
:=
make
([]
int64
,
0
,
len
(
users
))
for
i
:=
range
users
{
userIDs
=
append
(
userIDs
,
users
[
i
]
.
ID
)
}
readMap
,
err
:=
s
.
readRepo
.
GetReadMapByUsers
(
ctx
,
announcementID
,
userIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"get read map: %w"
,
err
)
}
out
:=
make
([]
AnnouncementUserReadStatus
,
0
,
len
(
users
))
for
i
:=
range
users
{
u
:=
users
[
i
]
subs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
u
.
ID
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
subs
))
for
j
:=
range
subs
{
activeGroupIDs
[
subs
[
j
]
.
GroupID
]
=
struct
{}{}
}
readAt
,
ok
:=
readMap
[
u
.
ID
]
var
ptr
*
time
.
Time
if
ok
{
t
:=
readAt
ptr
=
&
t
}
out
=
append
(
out
,
AnnouncementUserReadStatus
{
UserID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Balance
:
u
.
Balance
,
Eligible
:
domain
.
AnnouncementTargeting
(
ann
.
Targeting
)
.
Matches
(
u
.
Balance
,
activeGroupIDs
),
ReadAt
:
ptr
,
})
}
return
out
,
page
,
nil
}
func
isValidAnnouncementStatus
(
status
string
)
bool
{
switch
status
{
case
AnnouncementStatusDraft
,
AnnouncementStatusActive
,
AnnouncementStatusArchived
:
return
true
default
:
return
false
}
}
backend/internal/service/announcement_targeting_test.go
0 → 100644
View file @
0170d19f
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestAnnouncementTargeting_Matches_EmptyMatchesAll
(
t
*
testing
.
T
)
{
var
targeting
AnnouncementTargeting
require
.
True
(
t
,
targeting
.
Matches
(
0
,
nil
))
require
.
True
(
t
,
targeting
.
Matches
(
123.45
,
map
[
int64
]
struct
{}{
1
:
{}}))
}
func
TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
nil
},
},
}
_
,
err
:=
targeting
.
NormalizeAndValidate
()
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrAnnouncementInvalidTarget
)
}
func
TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
"balance"
,
Operator
:
"between"
,
Value
:
10
},
},
},
},
}
_
,
err
:=
targeting
.
NormalizeAndValidate
()
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrAnnouncementInvalidTarget
)
}
func
TestAnnouncementTargeting_Matches_AndOrSemantics
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
AnnouncementConditionTypeBalance
,
Operator
:
AnnouncementOperatorGTE
,
Value
:
100
},
{
Type
:
AnnouncementConditionTypeSubscription
,
Operator
:
AnnouncementOperatorIn
,
GroupIDs
:
[]
int64
{
10
}},
},
},
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
AnnouncementConditionTypeBalance
,
Operator
:
AnnouncementOperatorLT
,
Value
:
5
},
},
},
},
}
// 命中第 2 组(balance < 5)
require
.
True
(
t
,
targeting
.
Matches
(
4.99
,
nil
))
require
.
False
(
t
,
targeting
.
Matches
(
5
,
nil
))
// 命中第 1 组(balance >= 100 AND 订阅 in [10])
require
.
False
(
t
,
targeting
.
Matches
(
100
,
map
[
int64
]
struct
{}{}))
require
.
False
(
t
,
targeting
.
Matches
(
99.9
,
map
[
int64
]
struct
{}{
10
:
{}}))
require
.
True
(
t
,
targeting
.
Matches
(
100
,
map
[
int64
]
struct
{}{
10
:
{}}))
}
Prev
1
…
4
5
6
7
8
9
10
11
12
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment