"backend/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "421728a985af55133d2a2d07d10df1f05728bcef"
Commit 0170d19f authored by song's avatar song
Browse files

merge upstream main

parent 7ade9baa
...@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { ...@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour) endTime := base.Add(48 * time.Hour)
// Test with user filter // Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with apiKey filter // Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with both filters // Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
...@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { ...@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour) endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
...@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
// Test with user filter // Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil) stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with apiKey filter // Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with account filter // Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
...@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
dbuser.Or( dbuser.Or(
dbuser.EmailContainsFold(filters.Search), dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search), dbuser.UsernameContainsFold(filters.Search),
dbuser.NotesContainsFold(filters.Search),
), ),
) )
} }
...@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { ...@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt dst.UpdatedAt = src.UpdatedAt
} }
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
update := client.User.UpdateOneID(userID)
if encryptedSecret == nil {
update = update.ClearTotpSecretEncrypted()
} else {
update = update.SetTotpSecretEncrypted(*encryptedSecret)
}
_, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(true).
SetTotpEnabledAt(time.Now()).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(false).
ClearTotpEnabledAt().
ClearTotpSecretEncrypted().
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
...@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
} }
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query() q := client.UserSubscription.Query()
if userID != nil { if userID != nil {
...@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil { if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID)) q = q.Where(usersubscription.GroupIDEQ(*groupID))
} }
if status != "" {
// Status filtering with real-time expiration check
now := time.Now()
switch status {
case service.SubscriptionStatusActive:
// Active: status is active AND not yet expired
q = q.Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(now),
)
case service.SubscriptionStatusExpired:
// Expired: status is expired OR (status is active but already expired)
q = q.Where(
usersubscription.Or(
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
usersubscription.And(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(now),
),
),
)
case "":
// No filter
default:
// Other status (e.g., revoked)
q = q.Where(usersubscription.StatusEQ(status)) q = q.Where(usersubscription.StatusEQ(status))
} }
...@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
// Apply sorting
q = q.WithUser().WithGroup().WithAssignedByUser()
// Determine sort field
var field string
switch sortBy {
case "expires_at":
field = usersubscription.FieldExpiresAt
case "status":
field = usersubscription.FieldStatus
default:
field = usersubscription.FieldCreatedAt
}
// Determine sort order (default: desc)
if sortOrder == "asc" && sortBy != "" {
q = q.Order(dbent.Asc(field))
} else {
q = q.Order(dbent.Desc(field))
}
subs, err := q. subs, err := q.
WithUser().
WithGroup().
WithAssignedByUser().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit()).
All(ctx) All(ctx)
......
...@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { ...@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list") group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil) s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "") subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
...@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { ...@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID) s.Require().Equal(user1.ID, subs[0].UserID)
...@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { ...@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID) s.Require().Equal(g1.ID, subs[0].GroupID)
...@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { ...@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
}) })
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired) subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
......
...@@ -56,7 +56,10 @@ var ProviderSet = wire.NewSet( ...@@ -56,7 +56,10 @@ var ProviderSet = wire.NewSet(
NewProxyRepository, NewProxyRepository,
NewRedeemCodeRepository, NewRedeemCodeRepository,
NewPromoCodeRepository, NewPromoCodeRepository,
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository, NewUsageLogRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository, NewDashboardAggregationRepository,
NewSettingRepository, NewSettingRepository,
NewOpsRepository, NewOpsRepository,
...@@ -81,6 +84,10 @@ var ProviderSet = wire.NewSet( ...@@ -81,6 +84,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache, NewSchedulerCache,
NewSchedulerOutboxRepository, NewSchedulerOutboxRepository,
NewProxyLatencyCache, NewProxyLatencyCache,
NewTotpCache,
// Encryptors
NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly) // HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier, NewTurnstileVerifier,
......
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sort" "sort"
"strings"
"testing" "testing"
"time" "time"
...@@ -52,7 +51,6 @@ func TestAPIContracts(t *testing.T) { ...@@ -52,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1, "id": 1,
"email": "alice@example.com", "email": "alice@example.com",
"username": "alice", "username": "alice",
"notes": "hello",
"role": "user", "role": "user",
"balance": 12.5, "balance": 12.5,
"concurrency": 5, "concurrency": 5,
...@@ -132,6 +130,153 @@ func TestAPIContracts(t *testing.T) { ...@@ -132,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
} }
}`, }`,
}, },
{
name: "GET /api/v1/groups/available",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count)。
deps.groupRepo.SetActive([]service.Group{
{
ID: 10,
Name: "Group One",
Description: "desc",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.5,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-*": []int64{101, 102},
},
AccountCount: 2,
CreatedAt: deps.now,
UpdatedAt: deps.now,
},
})
deps.userSubRepo.SetActiveByUserID(1, nil)
},
method: http.MethodGet,
path: "/api/v1/groups/available",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 10,
"name": "Group One",
"description": "desc",
"platform": "anthropic",
"rate_multiplier": 1.5,
"is_exclusive": false,
"status": "active",
"subscription_type": "standard",
"daily_limit_usd": null,
"weekly_limit_usd": null,
"monthly_limit_usd": null,
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`,
},
{
name: "GET /api/v1/subscriptions",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
{
ID: 501,
UserID: 1,
GroupID: 10,
StartsAt: deps.now,
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34,
MonthlyUsageUSD: 3.45,
AssignedBy: ptr(int64(999)),
AssignedAt: deps.now,
Notes: "admin-note",
CreatedAt: deps.now,
UpdatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/subscriptions",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 501,
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
"monthly_window_start": null,
"daily_usage_usd": 1.23,
"weekly_usage_usd": 2.34,
"monthly_usage_usd": 3.45,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`,
},
{
name: "GET /api/v1/redeem/history",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户兑换历史不应包含 notes 等内部字段。
deps.redeemRepo.SetByUser(1, []service.RedeemCode{
{
ID: 900,
Code: "CODE-123",
Type: service.RedeemTypeBalance,
Value: 1.25,
Status: service.StatusUsed,
UsedBy: ptr(int64(1)),
UsedAt: ptr(deps.now),
Notes: "internal-note",
CreatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/redeem/history",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 900,
"code": "CODE-123",
"type": "balance",
"value": 1.25,
"status": "used",
"used_by": 1,
"used_at": "2025-01-02T03:04:05Z",
"created_at": "2025-01-02T03:04:05Z",
"group_id": null,
"validity_days": 0
}
]
}`,
},
{ {
name: "GET /api/v1/usage/stats", name: "GET /api/v1/usage/stats",
setup: func(t *testing.T, deps *contractDeps) { setup: func(t *testing.T, deps *contractDeps) {
...@@ -191,24 +336,25 @@ func TestAPIContracts(t *testing.T) { ...@@ -191,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
t.Helper() t.Helper()
deps.usageRepo.SetUserLogs(1, []service.UsageLog{ deps.usageRepo.SetUserLogs(1, []service.UsageLog{
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
APIKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
RequestID: "req_123", AccountRateMultiplier: ptr(0.5),
Model: "claude-3", RequestID: "req_123",
InputTokens: 10, Model: "claude-3",
OutputTokens: 20, InputTokens: 10,
CacheCreationTokens: 1, OutputTokens: 20,
CacheReadTokens: 2, CacheCreationTokens: 1,
TotalCost: 0.5, CacheReadTokens: 2,
ActualCost: 0.5, TotalCost: 0.5,
RateMultiplier: 1, ActualCost: 0.5,
BillingType: service.BillingTypeBalance, RateMultiplier: 1,
Stream: true, BillingType: service.BillingTypeBalance,
DurationMs: ptr(100), Stream: true,
FirstTokenMs: ptr(50), DurationMs: ptr(100),
CreatedAt: deps.now, FirstTokenMs: ptr(50),
CreatedAt: deps.now,
}, },
}) })
}, },
...@@ -239,10 +385,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -239,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
"output_cost": 0, "output_cost": 0,
"cache_creation_cost": 0, "cache_creation_cost": 0,
"cache_read_cost": 0, "cache_read_cost": 0,
"total_cost": 0.5, "total_cost": 0.5,
"actual_cost": 0.5, "actual_cost": 0.5,
"rate_multiplier": 1, "rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0, "billing_type": 0,
"stream": true, "stream": true,
"duration_ms": 100, "duration_ms": 100,
...@@ -267,6 +412,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -267,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps.settingRepo.SetAll(map[string]string{ deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false", service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com", service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587", service.SettingKeySMTPPort: "587",
...@@ -305,6 +451,10 @@ func TestAPIContracts(t *testing.T) { ...@@ -305,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
"data": { "data": {
"registration_enabled": true, "registration_enabled": true,
"email_verify_enabled": false, "email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com", "smtp_host": "smtp.example.com",
"smtp_port": 587, "smtp_port": 587,
"smtp_username": "user", "smtp_username": "user",
...@@ -338,45 +488,10 @@ func TestAPIContracts(t *testing.T) { ...@@ -338,45 +488,10 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o", "fallback_model_openai": "gpt-4o",
"enable_identity_patch": true, "enable_identity_patch": true,
"identity_patch_prompt": "", "identity_patch_prompt": "",
"home_content": "" "home_content": "",
} "hide_ccs_import_button": false,
}`, "purchase_subscription_enabled": false,
}, "purchase_subscription_url": ""
{
name: "POST /api/v1/admin/accounts/lookup",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.accountRepo.lookupAccounts = []service.Account{
{
ID: 101,
Name: "Alice Account",
Platform: "antigravity",
Credentials: map[string]any{
"email": "alice@example.com",
},
},
}
},
method: http.MethodPost,
path: "/api/v1/admin/accounts/lookup",
body: `{"platform":"antigravity","emails":["Alice@Example.com","bob@example.com"]}`,
headers: map[string]string{
"Content-Type": "application/json",
},
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"matched": [
{
"email": "alice@example.com",
"account_id": 101,
"platform": "antigravity",
"name": "Alice Account"
}
],
"missing": ["bob@example.com"]
} }
}`, }`,
}, },
...@@ -424,9 +539,11 @@ type contractDeps struct { ...@@ -424,9 +539,11 @@ type contractDeps struct {
now time.Time now time.Time
router http.Handler router http.Handler
apiKeyRepo *stubApiKeyRepo apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
usageRepo *stubUsageLogRepo usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo settingRepo *stubSettingRepo
accountRepo *stubAccountRepo redeemRepo *stubRedeemCodeRepo
} }
func newContractDeps(t *testing.T) *contractDeps { func newContractDeps(t *testing.T) *contractDeps {
...@@ -454,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -454,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyRepo := newStubApiKeyRepo(now) apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubApiKeyCache{} apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{} groupRepo := &stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{} userSubRepo := &stubUserSubscriptionRepo{}
accountRepo := stubAccountRepo{} accountRepo := stubAccountRepo{}
proxyRepo := stubProxyRepo{} proxyRepo := stubProxyRepo{}
redeemRepo := stubRedeemCodeRepo{} redeemRepo := &stubRedeemCodeRepo{}
cfg := &config.Config{ cfg := &config.Config{
Default: config.DefaultConfig{ Default: config.DefaultConfig{
...@@ -473,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -473,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
redeemHandler := handler.NewRedeemHandler(redeemService)
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
...@@ -512,25 +635,35 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -512,25 +635,35 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Keys.Use(jwtAuth) v1Keys.Use(jwtAuth)
v1Keys.GET("/keys", apiKeyHandler.List) v1Keys.GET("/keys", apiKeyHandler.List)
v1Keys.POST("/keys", apiKeyHandler.Create) v1Keys.POST("/keys", apiKeyHandler.Create)
v1Keys.GET("/groups/available", apiKeyHandler.GetAvailableGroups)
v1Usage := v1.Group("") v1Usage := v1.Group("")
v1Usage.Use(jwtAuth) v1Usage.Use(jwtAuth)
v1Usage.GET("/usage", usageHandler.List) v1Usage.GET("/usage", usageHandler.List)
v1Usage.GET("/usage/stats", usageHandler.Stats) v1Usage.GET("/usage/stats", usageHandler.Stats)
v1Subs := v1.Group("")
v1Subs.Use(jwtAuth)
v1Subs.GET("/subscriptions", subscriptionHandler.List)
v1Redeem := v1.Group("")
v1Redeem.Use(jwtAuth)
v1Redeem.GET("/redeem/history", redeemHandler.GetHistory)
v1Admin := v1.Group("/admin") v1Admin := v1.Group("/admin")
v1Admin.Use(adminAuth) v1Admin.Use(adminAuth)
v1Admin.GET("/settings", adminSettingHandler.GetSettings) v1Admin.GET("/settings", adminSettingHandler.GetSettings)
v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate) v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate)
v1Admin.POST("/accounts/lookup", adminAccountHandler.Lookup)
return &contractDeps{ return &contractDeps{
now: now, now: now,
router: r, router: r,
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
usageRepo: usageRepo, usageRepo: usageRepo,
settingRepo: settingRepo, settingRepo: settingRepo,
accountRepo: &accountRepo, redeemRepo: redeemRepo,
} }
} }
...@@ -626,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -626,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
type stubApiKeyCache struct{} type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
...@@ -660,7 +805,21 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { ...@@ -660,7 +805,21 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return nil return nil
} }
type stubGroupRepo struct{} func (stubApiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
type stubGroupRepo struct {
active []service.Group
}
func (r *stubGroupRepo) SetActive(groups []service.Group) {
r.active = append([]service.Group(nil), groups...)
}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
return errors.New("not implemented") return errors.New("not implemented")
...@@ -694,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi ...@@ -694,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
return nil, errors.New("not implemented") return append([]service.Group(nil), r.active...), nil
} }
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, errors.New("not implemented") out := make([]service.Group, 0, len(r.active))
for i := range r.active {
g := r.active[i]
if g.Platform == platform {
out = append(out, g)
}
}
return out, nil
} }
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
...@@ -715,8 +881,7 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i ...@@ -715,8 +881,7 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
} }
type stubAccountRepo struct { type stubAccountRepo struct {
bulkUpdateIDs []int64 bulkUpdateIDs []int64
lookupAccounts []service.Account
} }
func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error {
...@@ -767,36 +932,6 @@ func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ( ...@@ -767,36 +932,6 @@ func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) (
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (s *stubAccountRepo) ListByPlatformAndCredentialEmails(ctx context.Context, platform string, emails []string) ([]service.Account, error) {
if len(s.lookupAccounts) == 0 {
return nil, nil
}
emailSet := make(map[string]struct{}, len(emails))
for _, email := range emails {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" {
continue
}
emailSet[normalized] = struct{}{}
}
var matches []service.Account
for i := range s.lookupAccounts {
account := &s.lookupAccounts[i]
if account.Platform != platform {
continue
}
accountEmail := strings.ToLower(strings.TrimSpace(account.GetCredential("email")))
if accountEmail == "" {
continue
}
if _, ok := emailSet[accountEmail]; !ok {
continue
}
matches = append(matches, *account)
}
return matches, nil
}
func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
...@@ -948,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID ...@@ -948,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
type stubRedeemCodeRepo struct{} type stubRedeemCodeRepo struct {
byUser map[int64][]service.RedeemCode
}
func (r *stubRedeemCodeRepo) SetByUser(userID int64, codes []service.RedeemCode) {
if r.byUser == nil {
r.byUser = make(map[int64][]service.RedeemCode)
}
r.byUser[userID] = append([]service.RedeemCode(nil), codes...)
}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error { func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
return errors.New("not implemented") return errors.New("not implemented")
...@@ -986,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination ...@@ -986,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) { func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
return nil, errors.New("not implemented") if r.byUser == nil {
return nil, nil
}
codes := r.byUser[userID]
if limit > 0 && len(codes) > limit {
codes = codes[:limit]
}
return append([]service.RedeemCode(nil), codes...), nil
}
type stubUserSubscriptionRepo struct {
byUser map[int64][]service.UserSubscription
activeByUser map[int64][]service.UserSubscription
}
func (r *stubUserSubscriptionRepo) SetByUserID(userID int64, subs []service.UserSubscription) {
if r.byUser == nil {
r.byUser = make(map[int64][]service.UserSubscription)
}
r.byUser[userID] = append([]service.UserSubscription(nil), subs...)
} }
type stubUserSubscriptionRepo struct{} func (r *stubUserSubscriptionRepo) SetActiveByUserID(userID int64, subs []service.UserSubscription) {
if r.activeByUser == nil {
r.activeByUser = make(map[int64][]service.UserSubscription)
}
r.activeByUser[userID] = append([]service.UserSubscription(nil), subs...)
}
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented") return errors.New("not implemented")
...@@ -1010,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub ...@@ -1010,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error { func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented") if r.byUser == nil {
return nil, nil
}
return append([]service.UserSubscription(nil), r.byUser[userID]...), nil
} }
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented") if r.activeByUser == nil {
return nil, nil
}
return append([]service.UserSubscription(nil), r.activeByUser[userID]...), nil
} }
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
...@@ -1319,11 +1493,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D ...@@ -1319,11 +1493,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
...@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in ...@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -29,6 +29,9 @@ func RegisterAdminRoutes( ...@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
// 账号管理 // 账号管理
registerAccountRoutes(admin, h) registerAccountRoutes(admin, h)
// 公告管理
registerAnnouncementRoutes(admin, h)
// OpenAI OAuth // OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h) registerOpenAIOAuthRoutes(admin, h)
...@@ -197,7 +200,6 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -197,7 +200,6 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts := admin.Group("/accounts") accounts := admin.Group("/accounts")
{ {
accounts.GET("", h.Admin.Account.List) accounts.GET("", h.Admin.Account.List)
accounts.POST("/lookup", h.Admin.Account.Lookup)
accounts.GET("/:id", h.Admin.Account.GetByID) accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create) accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
...@@ -230,6 +232,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -230,6 +232,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerAnnouncementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
announcements := admin.Group("/announcements")
{
announcements.GET("", h.Admin.Announcement.List)
announcements.POST("", h.Admin.Announcement.Create)
announcements.GET("/:id", h.Admin.Announcement.GetByID)
announcements.PUT("/:id", h.Admin.Announcement.Update)
announcements.DELETE("/:id", h.Admin.Announcement.Delete)
announcements.GET("/:id/read-status", h.Admin.Announcement.ListReadStatus)
}
}
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
openai := admin.Group("/openai") openai := admin.Group("/openai")
{ {
...@@ -355,6 +369,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -355,6 +369,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("/stats", h.Admin.Usage.Stats) usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers) usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys) usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
usage.GET("/cleanup-tasks", h.Admin.Usage.ListCleanupTasks)
usage.POST("/cleanup-tasks", h.Admin.Usage.CreateCleanupTask)
usage.POST("/cleanup-tasks/:id/cancel", h.Admin.Usage.CancelCleanupTask)
} }
} }
......
...@@ -26,11 +26,20 @@ func RegisterAuthRoutes( ...@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{ {
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode) }), h.Auth.ValidatePromoCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ForgotPassword)
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
......
...@@ -22,6 +22,17 @@ func RegisterUserRoutes( ...@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile) user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword) user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile) user.PUT("", h.User.UpdateProfile)
// TOTP 双因素认证
totp := user.Group("/totp")
{
totp.GET("/status", h.Totp.GetStatus)
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
totp.POST("/send-code", h.Totp.SendVerifyCode)
totp.POST("/setup", h.Totp.InitiateSetup)
totp.POST("/enable", h.Totp.Enable)
totp.POST("/disable", h.Totp.Disable)
}
} }
// API Key管理 // API Key管理
...@@ -53,6 +64,13 @@ func RegisterUserRoutes( ...@@ -53,6 +64,13 @@ func RegisterUserRoutes(
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage) usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
} }
// 公告(用户可见)
announcements := authenticated.Group("/announcements")
{
announcements.GET("", h.Announcement.List)
announcements.POST("/:id/read", h.Announcement.MarkRead)
}
// 卡密兑换 // 卡密兑换
redeem := authenticated.Group("/redeem") redeem := authenticated.Group("/redeem")
{ {
......
...@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time { ...@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil return nil
} }
// GetCredentialAsInt64 解析凭证中的 int64 字段
// 用于读取 _token_version 等内部字段
func (a *Account) GetCredentialAsInt64(key string) int64 {
if a == nil || a.Credentials == nil {
return 0
}
val, ok := a.Credentials[key]
if !ok || val == nil {
return 0
}
switch v := val.(type) {
case int64:
return v
case float64:
return int64(v)
case int:
return int64(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return i
}
case string:
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
return i
}
}
return 0
}
func (a *Account) IsTempUnschedulableEnabled() bool { func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil { if a.Credentials == nil {
return false return false
...@@ -576,6 +605,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool { ...@@ -576,6 +605,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken) return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
} }
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
func (a *Account) IsTLSFingerprintEnabled() bool {
// 仅支持 Anthropic OAuth/SetupToken 账号
if !a.IsAnthropicOAuthOrSetupToken() {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
// 使上游认为请求来自同一个会话
func (a *Account) IsSessionIDMaskingEnabled() bool {
if !a.IsAnthropicOAuthOrSetupToken() {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["session_id_masking_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用 // 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 { func (a *Account) GetWindowCostLimit() float64 {
...@@ -652,6 +719,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo ...@@ -652,6 +719,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
return WindowCostNotSchedulable return WindowCostNotSchedulable
} }
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
// 逻辑:
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
func (a *Account) GetCurrentWindowStartTime() time.Time {
now := time.Now()
// 窗口未过期,使用记录的窗口开始时间
if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
return *a.SessionWindowStart
}
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
}
// parseExtraFloat64 从 extra 字段解析 float64 值 // parseExtraFloat64 从 extra 字段解析 float64 值
func parseExtraFloat64(value any) float64 { func parseExtraFloat64(value any) float64 {
switch v := value.(type) { switch v := value.(type) {
......
...@@ -33,7 +33,6 @@ type AccountRepository interface { ...@@ -33,7 +33,6 @@ type AccountRepository interface {
ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)
ListByPlatformAndCredentialEmails(ctx context.Context, platform string, emails []string) ([]Account, error)
UpdateLastUsed(ctx context.Context, id int64) error UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
......
...@@ -87,10 +87,6 @@ func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) ( ...@@ -87,10 +87,6 @@ func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) (
panic("unexpected ListByPlatform call") panic("unexpected ListByPlatform call")
} }
func (s *accountRepoStub) ListByPlatformAndCredentialEmails(ctx context.Context, platform string, emails []string) ([]Account, error) {
panic("unexpected ListByPlatformAndCredentialEmails call")
}
func (s *accountRepoStub) UpdateLastUsed(ctx context.Context, id int64) error { func (s *accountRepoStub) UpdateLastUsed(ctx context.Context, id int64) error {
panic("unexpected UpdateLastUsed call") panic("unexpected UpdateLastUsed call")
} }
......
...@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
...@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
...@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
......
...@@ -32,8 +32,8 @@ type UsageLogRepository interface { ...@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats // Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
...@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct { ...@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
} `json:"seven_day_sonnet"` } `json:"seven_day_sonnet"`
} }
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
type ClaudeUsageFetchOptions struct {
AccessToken string // OAuth access token
ProxyURL string // 代理 URL(可选)
AccountID int64 // 账号 ID(用于 TLS 指纹选择)
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等)
}
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API // ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type ClaudeUsageFetcher interface { type ClaudeUsageFetcher interface {
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
} }
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
...@@ -170,6 +181,7 @@ type AccountUsageService struct { ...@@ -170,6 +181,7 @@ type AccountUsageService struct {
geminiQuotaService *GeminiQuotaService geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher *AntigravityQuotaFetcher antigravityQuotaFetcher *AntigravityQuotaFetcher
cache *UsageCache cache *UsageCache
identityCache IdentityCache
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
...@@ -180,6 +192,7 @@ func NewAccountUsageService( ...@@ -180,6 +192,7 @@ func NewAccountUsageService(
geminiQuotaService *GeminiQuotaService, geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher, antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache, cache *UsageCache,
identityCache IdentityCache,
) *AccountUsageService { ) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -188,6 +201,7 @@ func NewAccountUsageService( ...@@ -188,6 +201,7 @@ func NewAccountUsageService(
geminiQuotaService: geminiQuotaService, geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: antigravityQuotaFetcher, antigravityQuotaFetcher: antigravityQuotaFetcher,
cache: cache, cache: cache,
identityCache: identityCache,
} }
} }
...@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
} }
dayStart := geminiDailyWindowStart(now) dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil) stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err) return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
} }
...@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute) minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute) minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil) minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
} }
...@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou ...@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 如果没有缓存,从数据库查询 // 如果没有缓存,从数据库查询
if windowStats == nil { if windowStats == nil {
var startTime time.Time // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
if account.SessionWindowStart != nil { startTime := account.GetCurrentWindowStartTime()
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil { if err != nil {
...@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI ...@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
} }
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo) // fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) { func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
...@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A ...@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) // 构建完整的选项
opts := &ClaudeUsageFetchOptions{
AccessToken: accessToken,
ProxyURL: proxyURL,
AccountID: account.ID,
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
}
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
if s.identityCache != nil {
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
opts.Fingerprint = fp
}
}
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
} }
// parseTime 尝试多种格式解析时间 // parseTime 尝试多种格式解析时间
......
...@@ -40,7 +40,6 @@ type AdminService interface { ...@@ -40,7 +40,6 @@ type AdminService interface {
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
LookupAccountsByCredentialEmail(ctx context.Context, platform string, emails []string) ([]Account, error)
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error SetAccountError(ctx context.Context, id int64, errorMsg string) error
...@@ -866,13 +865,6 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, ...@@ -866,13 +865,6 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return s.accountRepo.GetByID(ctx, id) return s.accountRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) LookupAccountsByCredentialEmail(ctx context.Context, platform string, emails []string) ([]Account, error) {
if platform == "" || len(emails) == 0 {
return []Account{}, nil
}
return s.accountRepo.ListByPlatformAndCredentialEmails(ctx, platform, emails)
}
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) { func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
if len(ids) == 0 { if len(ids) == 0 {
return []*Account{}, nil return []*Account{}, nil
......
...@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call") panic("unexpected RemoveGroupFromAllowedGroups call")
} }
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type groupRepoStub struct { type groupRepoStub struct {
affectedUserIDs []int64 affectedUserIDs []int64
deleteErr error deleteErr error
......
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
AnnouncementStatusDraft = domain.AnnouncementStatusDraft
AnnouncementStatusActive = domain.AnnouncementStatusActive
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
)
const (
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
)
const (
AnnouncementOperatorIn = domain.AnnouncementOperatorIn
AnnouncementOperatorGT = domain.AnnouncementOperatorGT
AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
AnnouncementOperatorLT = domain.AnnouncementOperatorLT
AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
)
var (
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
)
type AnnouncementTargeting = domain.AnnouncementTargeting
type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
type AnnouncementCondition = domain.AnnouncementCondition
type Announcement = domain.Announcement
type AnnouncementListFilters struct {
Status string
Search string
}
type AnnouncementRepository interface {
Create(ctx context.Context, a *Announcement) error
GetByID(ctx context.Context, id int64) (*Announcement, error)
Update(ctx context.Context, a *Announcement) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
}
type AnnouncementReadRepository interface {
MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
}
package service
import (
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type AnnouncementService struct {
announcementRepo AnnouncementRepository
readRepo AnnouncementReadRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
}
func NewAnnouncementService(
announcementRepo AnnouncementRepository,
readRepo AnnouncementReadRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
) *AnnouncementService {
return &AnnouncementService{
announcementRepo: announcementRepo,
readRepo: readRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
}
}
type CreateAnnouncementInput struct {
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
ActorID *int64 // 管理员用户ID
}
type UpdateAnnouncementInput struct {
Title *string
Content *string
Status *string
Targeting *AnnouncementTargeting
StartsAt **time.Time
EndsAt **time.Time
ActorID *int64 // 管理员用户ID
}
type UserAnnouncement struct {
Announcement Announcement
ReadAt *time.Time
}
type AnnouncementUserReadStatus struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Balance float64 `json:"balance"`
Eligible bool `json:"eligible"`
ReadAt *time.Time `json:"read_at,omitempty"`
}
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("create announcement: nil input")
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("create announcement: invalid title")
}
if content == "" {
return nil, fmt.Errorf("create announcement: content is required")
}
status := strings.TrimSpace(input.Status)
if status == "" {
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("create announcement: invalid status")
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
}
}
a := &Announcement{
Title: title,
Content: content,
Status: status,
Targeting: targeting,
StartsAt: input.StartsAt,
EndsAt: input.EndsAt,
}
if input.ActorID != nil && *input.ActorID > 0 {
a.CreatedBy = input.ActorID
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Create(ctx, a); err != nil {
return nil, fmt.Errorf("create announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("update announcement: nil input")
}
a, err := s.announcementRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("update announcement: invalid title")
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
return nil, fmt.Errorf("update announcement: content is required")
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("update announcement: invalid status")
}
a.Status = status
}
if input.Targeting != nil {
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
a.Targeting = targeting
}
if input.StartsAt != nil {
a.StartsAt = *input.StartsAt
}
if input.EndsAt != nil {
a.EndsAt = *input.EndsAt
}
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
}
}
if input.ActorID != nil && *input.ActorID > 0 {
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Update(ctx, a); err != nil {
return nil, fmt.Errorf("update announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Delete(ctx context.Context, id int64) error {
if err := s.announcementRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete announcement: %w", err)
}
return nil
}
func (s *AnnouncementService) GetByID(ctx context.Context, id int64) (*Announcement, error) {
return s.announcementRepo.GetByID(ctx, id)
}
func (s *AnnouncementService) List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
return s.announcementRepo.List(ctx, params, filters)
}
func (s *AnnouncementService) ListForUser(ctx context.Context, userID int64, unreadOnly bool) ([]UserAnnouncement, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
now := time.Now()
anns, err := s.announcementRepo.ListActive(ctx, now)
if err != nil {
return nil, fmt.Errorf("list active announcements: %w", err)
}
visible := make([]Announcement, 0, len(anns))
ids := make([]int64, 0, len(anns))
for i := range anns {
a := anns[i]
if !a.IsActiveAt(now) {
continue
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
continue
}
visible = append(visible, a)
ids = append(ids, a.ID)
}
if len(visible) == 0 {
return []UserAnnouncement{}, nil
}
readMap, err := s.readRepo.GetReadMapByUser(ctx, userID, ids)
if err != nil {
return nil, fmt.Errorf("get read map: %w", err)
}
out := make([]UserAnnouncement, 0, len(visible))
for i := range visible {
a := visible[i]
readAt, ok := readMap[a.ID]
if unreadOnly && ok {
continue
}
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, UserAnnouncement{
Announcement: a,
ReadAt: ptr,
})
}
// 未读优先、同状态按创建时间倒序
sort.Slice(out, func(i, j int) bool {
ai, aj := out[i], out[j]
if (ai.ReadAt == nil) != (aj.ReadAt == nil) {
return ai.ReadAt == nil
}
return ai.Announcement.ID > aj.Announcement.ID
})
return out, nil
}
func (s *AnnouncementService) MarkRead(ctx context.Context, userID, announcementID int64) error {
// 安全:仅允许标记当前用户“可见”的公告
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
a, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return err
}
now := time.Now()
if !a.IsActiveAt(now) {
return ErrAnnouncementNotFound
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
return ErrAnnouncementNotFound
}
if err := s.readRepo.MarkRead(ctx, announcementID, userID, now); err != nil {
return fmt.Errorf("mark read: %w", err)
}
return nil
}
func (s *AnnouncementService) ListUserReadStatus(
ctx context.Context,
announcementID int64,
params pagination.PaginationParams,
search string,
) ([]AnnouncementUserReadStatus, *pagination.PaginationResult, error) {
ann, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return nil, nil, err
}
filters := UserListFilters{
Search: strings.TrimSpace(search),
}
users, page, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
}
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
readMap, err := s.readRepo.GetReadMapByUsers(ctx, announcementID, userIDs)
if err != nil {
return nil, nil, fmt.Errorf("get read map: %w", err)
}
out := make([]AnnouncementUserReadStatus, 0, len(users))
for i := range users {
u := users[i]
subs, err := s.userSubRepo.ListActiveByUserID(ctx, u.ID)
if err != nil {
return nil, nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(subs))
for j := range subs {
activeGroupIDs[subs[j].GroupID] = struct{}{}
}
readAt, ok := readMap[u.ID]
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, AnnouncementUserReadStatus{
UserID: u.ID,
Email: u.Email,
Username: u.Username,
Balance: u.Balance,
Eligible: domain.AnnouncementTargeting(ann.Targeting).Matches(u.Balance, activeGroupIDs),
ReadAt: ptr,
})
}
return out, page, nil
}
func isValidAnnouncementStatus(status string) bool {
switch status {
case AnnouncementStatusDraft, AnnouncementStatusActive, AnnouncementStatusArchived:
return true
default:
return false
}
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAnnouncementTargeting_Matches_EmptyMatchesAll(t *testing.T) {
var targeting AnnouncementTargeting
require.True(t, targeting.Matches(0, nil))
require.True(t, targeting.Matches(123.45, map[int64]struct{}{1: {}}))
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{AllOf: nil},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: "balance", Operator: "between", Value: 10},
},
},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_Matches_AndOrSemantics(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorGTE, Value: 100},
{Type: AnnouncementConditionTypeSubscription, Operator: AnnouncementOperatorIn, GroupIDs: []int64{10}},
},
},
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorLT, Value: 5},
},
},
},
}
// 命中第 2 组(balance < 5)
require.True(t, targeting.Matches(4.99, nil))
require.False(t, targeting.Matches(5, nil))
// 命中第 1 组(balance >= 100 AND 订阅 in [10])
require.False(t, targeting.Matches(100, map[int64]struct{}{}))
require.False(t, targeting.Matches(99.9, map[int64]struct{}{10: {}}))
require.True(t, targeting.Matches(100, map[int64]struct{}{10: {}}))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment