Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
31fe0178
Commit
31fe0178
authored
Feb 03, 2026
by
yangjianbo
Browse files
Merge branch 'main' of
https://github.com/mt21625457/aicodex2api
parents
d9e345f2
ba5a0d47
Changes
235
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/user_subscription_repo.go
View file @
31fe0178
...
...
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return
userSubscriptionEntitiesToService
(
subs
),
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
q
:=
client
.
UserSubscription
.
Query
()
if
userID
!=
nil
{
...
...
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if
groupID
!=
nil
{
q
=
q
.
Where
(
usersubscription
.
GroupIDEQ
(
*
groupID
))
}
if
status
!=
""
{
// Status filtering with real-time expiration check
now
:=
time
.
Now
()
switch
status
{
case
service
.
SubscriptionStatusActive
:
// Active: status is active AND not yet expired
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtGT
(
now
),
)
case
service
.
SubscriptionStatusExpired
:
// Expired: status is expired OR (status is active but already expired)
q
=
q
.
Where
(
usersubscription
.
Or
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusExpired
),
usersubscription
.
And
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtLTE
(
now
),
),
),
)
case
""
:
// No filter
default
:
// Other status (e.g., revoked)
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
status
))
}
...
...
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return
nil
,
nil
,
err
}
// Apply sorting
q
=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
// Determine sort field
var
field
string
switch
sortBy
{
case
"expires_at"
:
field
=
usersubscription
.
FieldExpiresAt
case
"status"
:
field
=
usersubscription
.
FieldStatus
default
:
field
=
usersubscription
.
FieldCreatedAt
}
// Determine sort order (default: desc)
if
sortOrder
==
"asc"
&&
sortBy
!=
""
{
q
=
q
.
Order
(
dbent
.
Asc
(
field
))
}
else
{
q
=
q
.
Order
(
dbent
.
Desc
(
field
))
}
subs
,
err
:=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
.
Order
(
dbent
.
Desc
(
usersubscription
.
FieldCreatedAt
))
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
All
(
ctx
)
...
...
backend/internal/repository/user_subscription_repo_integration_test.go
View file @
31fe0178
...
...
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group
:=
s
.
mustCreateGroup
(
"g-list"
)
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
...
...
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s
.
mustCreateSubscription
(
user1
.
ID
,
group
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user2
.
ID
,
group
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
user1
.
ID
,
subs
[
0
]
.
UserID
)
...
...
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s
.
mustCreateSubscription
(
user
.
ID
,
g1
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user
.
ID
,
g2
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
g1
.
ID
,
subs
[
0
]
.
GroupID
)
...
...
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
-
24
*
time
.
Hour
))
})
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
service
.
SubscriptionStatusExpired
,
subs
[
0
]
.
Status
)
...
...
backend/internal/repository/wire.go
View file @
31fe0178
...
...
@@ -56,6 +56,8 @@ var ProviderSet = wire.NewSet(
NewProxyRepository
,
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewAnnouncementRepository
,
NewAnnouncementReadRepository
,
NewUsageLogRepository
,
NewUsageCleanupRepository
,
NewDashboardAggregationRepository
,
...
...
@@ -82,6 +84,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache
,
NewSchedulerOutboxRepository
,
NewProxyLatencyCache
,
NewTotpCache
,
// Encryptors
NewAESEncryptor
,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
...
...
backend/internal/server/api_contract_test.go
View file @
31fe0178
...
...
@@ -197,7 +197,7 @@ func TestAPIContracts(t *testing.T) {
UserID
:
1
,
GroupID
:
10
,
StartsAt
:
deps
.
now
,
ExpiresAt
:
deps
.
now
.
Add
(
24
*
time
.
Hour
),
ExpiresAt
:
time
.
Date
(
2099
,
1
,
2
,
3
,
4
,
5
,
0
,
time
.
UTC
),
// 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status
:
service
.
SubscriptionStatusActive
,
DailyUsageUSD
:
1.23
,
WeeklyUsageUSD
:
2.34
,
...
...
@@ -222,7 +222,7 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "20
25
-01-0
3
T03:04:05Z",
"expires_at": "20
99
-01-0
2
T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
...
...
@@ -452,6 +452,9 @@ func TestAPIContracts(t *testing.T) {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
...
...
@@ -485,8 +488,11 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"invitation_code_enabled": false,
"home_content": "",
"hide_ccs_import_button": false
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": ""
}
}`
,
},
...
...
@@ -595,7 +601,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
nil
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
...
...
@@ -754,6 +760,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
type
stubApiKeyCache
struct
{}
func
(
stubApiKeyCache
)
GetCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
...
...
@@ -863,6 +881,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
stubGroupRepo
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubGroupRepo
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubAccountRepo
struct
{
bulkUpdateIDs
[]
int64
}
...
...
@@ -1124,6 +1150,14 @@ func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit
return
append
([]
service
.
RedeemCode
(
nil
),
codes
...
),
nil
}
func
(
stubRedeemCodeRepo
)
ListByUserPaginated
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
,
codeType
string
)
([]
service
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
SumPositiveBalanceByUser
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubUserSubscriptionRepo
struct
{
byUser
map
[
int64
][]
service
.
UserSubscription
activeByUser
map
[
int64
][]
service
.
UserSubscription
...
...
@@ -1176,7 +1210,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func
(
stubUserSubscriptionRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
31fe0178
...
...
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/routes/admin.go
View file @
31fe0178
...
...
@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
// 账号管理
registerAccountRoutes
(
admin
,
h
)
// 公告管理
registerAnnouncementRoutes
(
admin
,
h
)
// OpenAI OAuth
registerOpenAIOAuthRoutes
(
admin
,
h
)
...
...
@@ -172,6 +175,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users
.
POST
(
"/:id/balance"
,
h
.
Admin
.
User
.
UpdateBalance
)
users
.
GET
(
"/:id/api-keys"
,
h
.
Admin
.
User
.
GetUserAPIKeys
)
users
.
GET
(
"/:id/usage"
,
h
.
Admin
.
User
.
GetUserUsage
)
users
.
GET
(
"/:id/balance-history"
,
h
.
Admin
.
User
.
GetBalanceHistory
)
// User attribute values
users
.
GET
(
"/:id/attributes"
,
h
.
Admin
.
UserAttribute
.
GetUserAttributes
)
...
...
@@ -229,6 +233,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerAnnouncementRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
announcements
:=
admin
.
Group
(
"/announcements"
)
{
announcements
.
GET
(
""
,
h
.
Admin
.
Announcement
.
List
)
announcements
.
POST
(
""
,
h
.
Admin
.
Announcement
.
Create
)
announcements
.
GET
(
"/:id"
,
h
.
Admin
.
Announcement
.
GetByID
)
announcements
.
PUT
(
"/:id"
,
h
.
Admin
.
Announcement
.
Update
)
announcements
.
DELETE
(
"/:id"
,
h
.
Admin
.
Announcement
.
Delete
)
announcements
.
GET
(
"/:id/read-status"
,
h
.
Admin
.
Announcement
.
ListReadStatus
)
}
}
func
registerOpenAIOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
openai
:=
admin
.
Group
(
"/openai"
)
{
...
...
backend/internal/server/routes/auth.go
View file @
31fe0178
...
...
@@ -26,11 +26,24 @@ func RegisterAuthRoutes(
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/validate-promo-code"
,
rateLimiter
.
LimitWithOptions
(
"validate-promo"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ValidatePromoCode
)
// 邀请码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/validate-invitation-code"
,
rateLimiter
.
LimitWithOptions
(
"validate-invitation"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ValidateInvitationCode
)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth
.
POST
(
"/forgot-password"
,
rateLimiter
.
LimitWithOptions
(
"forgot-password"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ForgotPassword
)
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/reset-password"
,
rateLimiter
.
LimitWithOptions
(
"reset-password"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ResetPassword
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
...
...
backend/internal/server/routes/user.go
View file @
31fe0178
...
...
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user
.
GET
(
"/profile"
,
h
.
User
.
GetProfile
)
user
.
PUT
(
"/password"
,
h
.
User
.
ChangePassword
)
user
.
PUT
(
""
,
h
.
User
.
UpdateProfile
)
// TOTP 双因素认证
totp
:=
user
.
Group
(
"/totp"
)
{
totp
.
GET
(
"/status"
,
h
.
Totp
.
GetStatus
)
totp
.
GET
(
"/verification-method"
,
h
.
Totp
.
GetVerificationMethod
)
totp
.
POST
(
"/send-code"
,
h
.
Totp
.
SendVerifyCode
)
totp
.
POST
(
"/setup"
,
h
.
Totp
.
InitiateSetup
)
totp
.
POST
(
"/enable"
,
h
.
Totp
.
Enable
)
totp
.
POST
(
"/disable"
,
h
.
Totp
.
Disable
)
}
}
// API Key管理
...
...
@@ -53,6 +64,13 @@ func RegisterUserRoutes(
usage
.
POST
(
"/dashboard/api-keys-usage"
,
h
.
Usage
.
DashboardAPIKeysUsage
)
}
// 公告(用户可见)
announcements
:=
authenticated
.
Group
(
"/announcements"
)
{
announcements
.
GET
(
""
,
h
.
Announcement
.
List
)
announcements
.
POST
(
"/:id/read"
,
h
.
Announcement
.
MarkRead
)
}
// 卡密兑换
redeem
:=
authenticated
.
Group
(
"/redeem"
)
{
...
...
backend/internal/service/account.go
View file @
31fe0178
...
...
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return
""
}
func
(
a
*
Account
)
GetClaudeUserID
()
string
{
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
return
""
}
func
(
a
*
Account
)
IsCustomErrorCodesEnabled
()
bool
{
if
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Credentials
==
nil
{
return
false
...
...
backend/internal/service/account_test_service.go
View file @
31fe0178
...
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system"
:
[]
map
[
string
]
any
{
{
"type"
:
"text"
,
"text"
:
"You are C
laude
Code
, Anthropic's official CLI for Claude."
,
"text"
:
c
laudeCode
SystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
},
...
...
backend/internal/service/admin_service.go
View file @
31fe0178
...
...
@@ -22,6 +22,10 @@ type AdminService interface {
UpdateUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
,
operation
string
,
notes
string
)
(
*
User
,
error
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
)
([]
APIKey
,
int64
,
error
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
codeType
string
)
([]
RedeemCode
,
int64
,
float64
,
error
)
// Group management
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
...
...
@@ -110,6 +114,8 @@ type CreateGroupInput struct {
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// 是否启用模型路由
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
type
UpdateGroupInput
struct
{
...
...
@@ -132,6 +138,8 @@ type UpdateGroupInput struct {
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
*
bool
// 是否启用模型路由
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
type
CreateAccountInput
struct
{
...
...
@@ -522,6 +530,21 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
},
nil
}
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
func
(
s
*
adminServiceImpl
)
GetUserBalanceHistory
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
codeType
string
)
([]
RedeemCode
,
int64
,
float64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
codes
,
result
,
err
:=
s
.
redeemCodeRepo
.
ListByUserPaginated
(
ctx
,
userID
,
params
,
codeType
)
if
err
!=
nil
{
return
nil
,
0
,
0
,
err
}
// Aggregate total recharged amount (only once, regardless of type filter)
totalRecharged
,
err
:=
s
.
redeemCodeRepo
.
SumPositiveBalanceByUser
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
0
,
0
,
err
}
return
codes
,
result
.
Total
,
totalRecharged
,
nil
}
// Group management implementations
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
...
...
@@ -572,6 +595,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var
accountIDsToCopy
[]
int64
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
seen
:=
make
(
map
[
int64
]
struct
{})
uniqueSourceGroupIDs
:=
make
([]
int64
,
0
,
len
(
input
.
CopyAccountsFromGroupIDs
))
for
_
,
srcGroupID
:=
range
input
.
CopyAccountsFromGroupIDs
{
if
_
,
exists
:=
seen
[
srcGroupID
];
!
exists
{
seen
[
srcGroupID
]
=
struct
{}{}
uniqueSourceGroupIDs
=
append
(
uniqueSourceGroupIDs
,
srcGroupID
)
}
}
// 校验源分组的平台是否与新分组一致
for
_
,
srcGroupID
:=
range
uniqueSourceGroupIDs
{
srcGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
srcGroupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"source group %d not found: %w"
,
srcGroupID
,
err
)
}
if
srcGroup
.
Platform
!=
platform
{
return
nil
,
fmt
.
Errorf
(
"source group %d platform mismatch: expected %s, got %s"
,
srcGroupID
,
platform
,
srcGroup
.
Platform
)
}
}
// 获取所有源分组的账号(去重)
var
err
error
accountIDsToCopy
,
err
=
s
.
groupRepo
.
GetAccountIDsByGroupIDs
(
ctx
,
uniqueSourceGroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get accounts from source groups: %w"
,
err
)
}
}
group
:=
&
Group
{
Name
:
input
.
Name
,
Description
:
input
.
Description
,
...
...
@@ -593,6 +648,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
// 如果有需要复制的账号,绑定到新分组
if
len
(
accountIDsToCopy
)
>
0
{
if
err
:=
s
.
groupRepo
.
BindAccountsToGroup
(
ctx
,
group
.
ID
,
accountIDsToCopy
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to bind accounts to new group: %w"
,
err
)
}
group
.
AccountCount
=
int64
(
len
(
accountIDsToCopy
))
}
return
group
,
nil
}
...
...
@@ -728,6 +792,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
seen
:=
make
(
map
[
int64
]
struct
{})
uniqueSourceGroupIDs
:=
make
([]
int64
,
0
,
len
(
input
.
CopyAccountsFromGroupIDs
))
for
_
,
srcGroupID
:=
range
input
.
CopyAccountsFromGroupIDs
{
// 校验:源分组不能是自身
if
srcGroupID
==
id
{
return
nil
,
fmt
.
Errorf
(
"cannot copy accounts from self"
)
}
// 去重
if
_
,
exists
:=
seen
[
srcGroupID
];
!
exists
{
seen
[
srcGroupID
]
=
struct
{}{}
uniqueSourceGroupIDs
=
append
(
uniqueSourceGroupIDs
,
srcGroupID
)
}
}
// 校验源分组的平台是否与当前分组一致
for
_
,
srcGroupID
:=
range
uniqueSourceGroupIDs
{
srcGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
srcGroupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"source group %d not found: %w"
,
srcGroupID
,
err
)
}
if
srcGroup
.
Platform
!=
group
.
Platform
{
return
nil
,
fmt
.
Errorf
(
"source group %d platform mismatch: expected %s, got %s"
,
srcGroupID
,
group
.
Platform
,
srcGroup
.
Platform
)
}
}
// 获取所有源分组的账号(去重)
accountIDsToCopy
,
err
:=
s
.
groupRepo
.
GetAccountIDsByGroupIDs
(
ctx
,
uniqueSourceGroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get accounts from source groups: %w"
,
err
)
}
// 先清空当前分组的所有账号绑定
if
_
,
err
:=
s
.
groupRepo
.
DeleteAccountGroupsByGroupID
(
ctx
,
id
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to clear existing account bindings: %w"
,
err
)
}
// 再绑定源分组的账号
if
len
(
accountIDsToCopy
)
>
0
{
if
err
:=
s
.
groupRepo
.
BindAccountsToGroup
(
ctx
,
id
,
accountIDsToCopy
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to bind accounts to group: %w"
,
err
)
}
}
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
...
...
backend/internal/service/admin_service_delete_test.go
View file @
31fe0178
...
...
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic
(
"unexpected RemoveGroupFromAllowedGroups call"
)
}
func
(
s
*
userRepoStub
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
panic
(
"unexpected UpdateTotpSecret call"
)
}
func
(
s
*
userRepoStub
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected EnableTotp call"
)
}
func
(
s
*
userRepoStub
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected DisableTotp call"
)
}
type
groupRepoStub
struct
{
affectedUserIDs
[]
int64
deleteErr
error
...
...
@@ -152,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
func
(
s
*
groupRepoStub
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStub
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
type
proxyRepoStub
struct
{
deleteErr
error
countErr
error
...
...
@@ -262,6 +282,14 @@ func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int
panic
(
"unexpected ListByUser call"
)
}
func
(
s
*
redeemRepoStub
)
ListByUserPaginated
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
,
codeType
string
)
([]
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListByUserPaginated call"
)
}
func
(
s
*
redeemRepoStub
)
SumPositiveBalanceByUser
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
error
)
{
panic
(
"unexpected SumPositiveBalanceByUser call"
)
}
type
subscriptionInvalidateCall
struct
{
userID
int64
groupID
int64
...
...
backend/internal/service/admin_service_group_test.go
View file @
31fe0178
...
...
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
func
(
s
*
groupRepoStubForAdmin
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStubForAdmin
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
...
...
@@ -378,3 +386,11 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
backend/internal/service/admin_service_search_test.go
View file @
31fe0178
...
...
@@ -152,6 +152,14 @@ func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params p
return
s
.
listWithFiltersCodes
,
result
,
nil
}
func
(
s
*
redeemRepoStubForAdminList
)
ListByUserPaginated
(
_
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
,
codeType
string
)
([]
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListByUserPaginated call"
)
}
func
(
s
*
redeemRepoStubForAdminList
)
SumPositiveBalanceByUser
(
_
context
.
Context
,
userID
int64
)
(
float64
,
error
)
{
panic
(
"unexpected SumPositiveBalanceByUser call"
)
}
func
TestAdminService_ListAccounts_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForAdminList
{
...
...
backend/internal/service/announcement.go
0 → 100644
View file @
31fe0178
package
service
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const
(
AnnouncementStatusDraft
=
domain
.
AnnouncementStatusDraft
AnnouncementStatusActive
=
domain
.
AnnouncementStatusActive
AnnouncementStatusArchived
=
domain
.
AnnouncementStatusArchived
)
const
(
AnnouncementConditionTypeSubscription
=
domain
.
AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance
=
domain
.
AnnouncementConditionTypeBalance
)
const
(
AnnouncementOperatorIn
=
domain
.
AnnouncementOperatorIn
AnnouncementOperatorGT
=
domain
.
AnnouncementOperatorGT
AnnouncementOperatorGTE
=
domain
.
AnnouncementOperatorGTE
AnnouncementOperatorLT
=
domain
.
AnnouncementOperatorLT
AnnouncementOperatorLTE
=
domain
.
AnnouncementOperatorLTE
AnnouncementOperatorEQ
=
domain
.
AnnouncementOperatorEQ
)
var
(
ErrAnnouncementNotFound
=
domain
.
ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget
=
domain
.
ErrAnnouncementInvalidTarget
)
type
AnnouncementTargeting
=
domain
.
AnnouncementTargeting
type
AnnouncementConditionGroup
=
domain
.
AnnouncementConditionGroup
type
AnnouncementCondition
=
domain
.
AnnouncementCondition
type
Announcement
=
domain
.
Announcement
type
AnnouncementListFilters
struct
{
Status
string
Search
string
}
type
AnnouncementRepository
interface
{
Create
(
ctx
context
.
Context
,
a
*
Announcement
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Announcement
,
error
)
Update
(
ctx
context
.
Context
,
a
*
Announcement
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
AnnouncementListFilters
)
([]
Announcement
,
*
pagination
.
PaginationResult
,
error
)
ListActive
(
ctx
context
.
Context
,
now
time
.
Time
)
([]
Announcement
,
error
)
}
type
AnnouncementReadRepository
interface
{
MarkRead
(
ctx
context
.
Context
,
announcementID
,
userID
int64
,
readAt
time
.
Time
)
error
GetReadMapByUser
(
ctx
context
.
Context
,
userID
int64
,
announcementIDs
[]
int64
)
(
map
[
int64
]
time
.
Time
,
error
)
GetReadMapByUsers
(
ctx
context
.
Context
,
announcementID
int64
,
userIDs
[]
int64
)
(
map
[
int64
]
time
.
Time
,
error
)
CountByAnnouncementID
(
ctx
context
.
Context
,
announcementID
int64
)
(
int64
,
error
)
}
backend/internal/service/announcement_service.go
0 → 100644
View file @
31fe0178
package
service
import
(
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type
AnnouncementService
struct
{
announcementRepo
AnnouncementRepository
readRepo
AnnouncementReadRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
}
func
NewAnnouncementService
(
announcementRepo
AnnouncementRepository
,
readRepo
AnnouncementReadRepository
,
userRepo
UserRepository
,
userSubRepo
UserSubscriptionRepository
,
)
*
AnnouncementService
{
return
&
AnnouncementService
{
announcementRepo
:
announcementRepo
,
readRepo
:
readRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
}
}
type
CreateAnnouncementInput
struct
{
Title
string
Content
string
Status
string
Targeting
AnnouncementTargeting
StartsAt
*
time
.
Time
EndsAt
*
time
.
Time
ActorID
*
int64
// 管理员用户ID
}
type
UpdateAnnouncementInput
struct
{
Title
*
string
Content
*
string
Status
*
string
Targeting
*
AnnouncementTargeting
StartsAt
**
time
.
Time
EndsAt
**
time
.
Time
ActorID
*
int64
// 管理员用户ID
}
type
UserAnnouncement
struct
{
Announcement
Announcement
ReadAt
*
time
.
Time
}
type
AnnouncementUserReadStatus
struct
{
UserID
int64
`json:"user_id"`
Email
string
`json:"email"`
Username
string
`json:"username"`
Balance
float64
`json:"balance"`
Eligible
bool
`json:"eligible"`
ReadAt
*
time
.
Time
`json:"read_at,omitempty"`
}
func
(
s
*
AnnouncementService
)
Create
(
ctx
context
.
Context
,
input
*
CreateAnnouncementInput
)
(
*
Announcement
,
error
)
{
if
input
==
nil
{
return
nil
,
fmt
.
Errorf
(
"create announcement: nil input"
)
}
title
:=
strings
.
TrimSpace
(
input
.
Title
)
content
:=
strings
.
TrimSpace
(
input
.
Content
)
if
title
==
""
||
len
(
title
)
>
200
{
return
nil
,
fmt
.
Errorf
(
"create announcement: invalid title"
)
}
if
content
==
""
{
return
nil
,
fmt
.
Errorf
(
"create announcement: content is required"
)
}
status
:=
strings
.
TrimSpace
(
input
.
Status
)
if
status
==
""
{
status
=
AnnouncementStatusDraft
}
if
!
isValidAnnouncementStatus
(
status
)
{
return
nil
,
fmt
.
Errorf
(
"create announcement: invalid status"
)
}
targeting
,
err
:=
domain
.
AnnouncementTargeting
(
input
.
Targeting
)
.
NormalizeAndValidate
()
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
StartsAt
!=
nil
&&
input
.
EndsAt
!=
nil
{
if
!
input
.
StartsAt
.
Before
(
*
input
.
EndsAt
)
{
return
nil
,
fmt
.
Errorf
(
"create announcement: starts_at must be before ends_at"
)
}
}
a
:=
&
Announcement
{
Title
:
title
,
Content
:
content
,
Status
:
status
,
Targeting
:
targeting
,
StartsAt
:
input
.
StartsAt
,
EndsAt
:
input
.
EndsAt
,
}
if
input
.
ActorID
!=
nil
&&
*
input
.
ActorID
>
0
{
a
.
CreatedBy
=
input
.
ActorID
a
.
UpdatedBy
=
input
.
ActorID
}
if
err
:=
s
.
announcementRepo
.
Create
(
ctx
,
a
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create announcement: %w"
,
err
)
}
return
a
,
nil
}
func
(
s
*
AnnouncementService
)
Update
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAnnouncementInput
)
(
*
Announcement
,
error
)
{
if
input
==
nil
{
return
nil
,
fmt
.
Errorf
(
"update announcement: nil input"
)
}
a
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
Title
!=
nil
{
title
:=
strings
.
TrimSpace
(
*
input
.
Title
)
if
title
==
""
||
len
(
title
)
>
200
{
return
nil
,
fmt
.
Errorf
(
"update announcement: invalid title"
)
}
a
.
Title
=
title
}
if
input
.
Content
!=
nil
{
content
:=
strings
.
TrimSpace
(
*
input
.
Content
)
if
content
==
""
{
return
nil
,
fmt
.
Errorf
(
"update announcement: content is required"
)
}
a
.
Content
=
content
}
if
input
.
Status
!=
nil
{
status
:=
strings
.
TrimSpace
(
*
input
.
Status
)
if
!
isValidAnnouncementStatus
(
status
)
{
return
nil
,
fmt
.
Errorf
(
"update announcement: invalid status"
)
}
a
.
Status
=
status
}
if
input
.
Targeting
!=
nil
{
targeting
,
err
:=
domain
.
AnnouncementTargeting
(
*
input
.
Targeting
)
.
NormalizeAndValidate
()
if
err
!=
nil
{
return
nil
,
err
}
a
.
Targeting
=
targeting
}
if
input
.
StartsAt
!=
nil
{
a
.
StartsAt
=
*
input
.
StartsAt
}
if
input
.
EndsAt
!=
nil
{
a
.
EndsAt
=
*
input
.
EndsAt
}
if
a
.
StartsAt
!=
nil
&&
a
.
EndsAt
!=
nil
{
if
!
a
.
StartsAt
.
Before
(
*
a
.
EndsAt
)
{
return
nil
,
fmt
.
Errorf
(
"update announcement: starts_at must be before ends_at"
)
}
}
if
input
.
ActorID
!=
nil
&&
*
input
.
ActorID
>
0
{
a
.
UpdatedBy
=
input
.
ActorID
}
if
err
:=
s
.
announcementRepo
.
Update
(
ctx
,
a
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update announcement: %w"
,
err
)
}
return
a
,
nil
}
func
(
s
*
AnnouncementService
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
if
err
:=
s
.
announcementRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete announcement: %w"
,
err
)
}
return
nil
}
func
(
s
*
AnnouncementService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Announcement
,
error
)
{
return
s
.
announcementRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
AnnouncementService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
AnnouncementListFilters
)
([]
Announcement
,
*
pagination
.
PaginationResult
,
error
)
{
return
s
.
announcementRepo
.
List
(
ctx
,
params
,
filters
)
}
func
(
s
*
AnnouncementService
)
ListForUser
(
ctx
context
.
Context
,
userID
int64
,
unreadOnly
bool
)
([]
UserAnnouncement
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
activeSubs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
activeSubs
))
for
i
:=
range
activeSubs
{
activeGroupIDs
[
activeSubs
[
i
]
.
GroupID
]
=
struct
{}{}
}
now
:=
time
.
Now
()
anns
,
err
:=
s
.
announcementRepo
.
ListActive
(
ctx
,
now
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active announcements: %w"
,
err
)
}
visible
:=
make
([]
Announcement
,
0
,
len
(
anns
))
ids
:=
make
([]
int64
,
0
,
len
(
anns
))
for
i
:=
range
anns
{
a
:=
anns
[
i
]
if
!
a
.
IsActiveAt
(
now
)
{
continue
}
if
!
a
.
Targeting
.
Matches
(
user
.
Balance
,
activeGroupIDs
)
{
continue
}
visible
=
append
(
visible
,
a
)
ids
=
append
(
ids
,
a
.
ID
)
}
if
len
(
visible
)
==
0
{
return
[]
UserAnnouncement
{},
nil
}
readMap
,
err
:=
s
.
readRepo
.
GetReadMapByUser
(
ctx
,
userID
,
ids
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get read map: %w"
,
err
)
}
out
:=
make
([]
UserAnnouncement
,
0
,
len
(
visible
))
for
i
:=
range
visible
{
a
:=
visible
[
i
]
readAt
,
ok
:=
readMap
[
a
.
ID
]
if
unreadOnly
&&
ok
{
continue
}
var
ptr
*
time
.
Time
if
ok
{
t
:=
readAt
ptr
=
&
t
}
out
=
append
(
out
,
UserAnnouncement
{
Announcement
:
a
,
ReadAt
:
ptr
,
})
}
// 未读优先、同状态按创建时间倒序
sort
.
Slice
(
out
,
func
(
i
,
j
int
)
bool
{
ai
,
aj
:=
out
[
i
],
out
[
j
]
if
(
ai
.
ReadAt
==
nil
)
!=
(
aj
.
ReadAt
==
nil
)
{
return
ai
.
ReadAt
==
nil
}
return
ai
.
Announcement
.
ID
>
aj
.
Announcement
.
ID
})
return
out
,
nil
}
func
(
s
*
AnnouncementService
)
MarkRead
(
ctx
context
.
Context
,
userID
,
announcementID
int64
)
error
{
// 安全:仅允许标记当前用户“可见”的公告
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
a
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
announcementID
)
if
err
!=
nil
{
return
err
}
now
:=
time
.
Now
()
if
!
a
.
IsActiveAt
(
now
)
{
return
ErrAnnouncementNotFound
}
activeSubs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
activeSubs
))
for
i
:=
range
activeSubs
{
activeGroupIDs
[
activeSubs
[
i
]
.
GroupID
]
=
struct
{}{}
}
if
!
a
.
Targeting
.
Matches
(
user
.
Balance
,
activeGroupIDs
)
{
return
ErrAnnouncementNotFound
}
if
err
:=
s
.
readRepo
.
MarkRead
(
ctx
,
announcementID
,
userID
,
now
);
err
!=
nil
{
return
fmt
.
Errorf
(
"mark read: %w"
,
err
)
}
return
nil
}
func
(
s
*
AnnouncementService
)
ListUserReadStatus
(
ctx
context
.
Context
,
announcementID
int64
,
params
pagination
.
PaginationParams
,
search
string
,
)
([]
AnnouncementUserReadStatus
,
*
pagination
.
PaginationResult
,
error
)
{
ann
,
err
:=
s
.
announcementRepo
.
GetByID
(
ctx
,
announcementID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
filters
:=
UserListFilters
{
Search
:
strings
.
TrimSpace
(
search
),
}
users
,
page
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
filters
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list users: %w"
,
err
)
}
userIDs
:=
make
([]
int64
,
0
,
len
(
users
))
for
i
:=
range
users
{
userIDs
=
append
(
userIDs
,
users
[
i
]
.
ID
)
}
readMap
,
err
:=
s
.
readRepo
.
GetReadMapByUsers
(
ctx
,
announcementID
,
userIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"get read map: %w"
,
err
)
}
out
:=
make
([]
AnnouncementUserReadStatus
,
0
,
len
(
users
))
for
i
:=
range
users
{
u
:=
users
[
i
]
subs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
u
.
ID
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
activeGroupIDs
:=
make
(
map
[
int64
]
struct
{},
len
(
subs
))
for
j
:=
range
subs
{
activeGroupIDs
[
subs
[
j
]
.
GroupID
]
=
struct
{}{}
}
readAt
,
ok
:=
readMap
[
u
.
ID
]
var
ptr
*
time
.
Time
if
ok
{
t
:=
readAt
ptr
=
&
t
}
out
=
append
(
out
,
AnnouncementUserReadStatus
{
UserID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Balance
:
u
.
Balance
,
Eligible
:
domain
.
AnnouncementTargeting
(
ann
.
Targeting
)
.
Matches
(
u
.
Balance
,
activeGroupIDs
),
ReadAt
:
ptr
,
})
}
return
out
,
page
,
nil
}
func
isValidAnnouncementStatus
(
status
string
)
bool
{
switch
status
{
case
AnnouncementStatusDraft
,
AnnouncementStatusActive
,
AnnouncementStatusArchived
:
return
true
default
:
return
false
}
}
backend/internal/service/announcement_targeting_test.go
0 → 100644
View file @
31fe0178
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestAnnouncementTargeting_Matches_EmptyMatchesAll
(
t
*
testing
.
T
)
{
var
targeting
AnnouncementTargeting
require
.
True
(
t
,
targeting
.
Matches
(
0
,
nil
))
require
.
True
(
t
,
targeting
.
Matches
(
123.45
,
map
[
int64
]
struct
{}{
1
:
{}}))
}
func
TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
nil
},
},
}
_
,
err
:=
targeting
.
NormalizeAndValidate
()
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrAnnouncementInvalidTarget
)
}
func
TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
"balance"
,
Operator
:
"between"
,
Value
:
10
},
},
},
},
}
_
,
err
:=
targeting
.
NormalizeAndValidate
()
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrAnnouncementInvalidTarget
)
}
func
TestAnnouncementTargeting_Matches_AndOrSemantics
(
t
*
testing
.
T
)
{
targeting
:=
AnnouncementTargeting
{
AnyOf
:
[]
AnnouncementConditionGroup
{
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
AnnouncementConditionTypeBalance
,
Operator
:
AnnouncementOperatorGTE
,
Value
:
100
},
{
Type
:
AnnouncementConditionTypeSubscription
,
Operator
:
AnnouncementOperatorIn
,
GroupIDs
:
[]
int64
{
10
}},
},
},
{
AllOf
:
[]
AnnouncementCondition
{
{
Type
:
AnnouncementConditionTypeBalance
,
Operator
:
AnnouncementOperatorLT
,
Value
:
5
},
},
},
},
}
// 命中第 2 组(balance < 5)
require
.
True
(
t
,
targeting
.
Matches
(
4.99
,
nil
))
require
.
False
(
t
,
targeting
.
Matches
(
5
,
nil
))
// 命中第 1 组(balance >= 100 AND 订阅 in [10])
require
.
False
(
t
,
targeting
.
Matches
(
100
,
map
[
int64
]
struct
{}{}))
require
.
False
(
t
,
targeting
.
Matches
(
99.9
,
map
[
int64
]
struct
{}{
10
:
{}}))
require
.
True
(
t
,
targeting
.
Matches
(
100
,
map
[
int64
]
struct
{}{
10
:
{}}))
}
backend/internal/service/antigravity_gateway_service.go
View file @
31fe0178
...
...
@@ -273,13 +273,11 @@ func logPrefix(sessionID, accountName string) string {
}
// Antigravity 直接支持的模型(精确匹配透传)
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
var
antigravitySupportedModels
=
map
[
string
]
bool
{
"claude-opus-4-5-thinking"
:
true
,
"claude-sonnet-4-5"
:
true
,
"claude-sonnet-4-5-thinking"
:
true
,
"gemini-2.5-flash"
:
true
,
"gemini-2.5-flash-lite"
:
true
,
"gemini-2.5-flash-thinking"
:
true
,
"gemini-3-flash"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-high"
:
true
,
...
...
@@ -288,23 +286,32 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
var
antigravityPrefixMapping
=
[]
struct
{
prefix
string
target
string
}{
// 长前缀优先
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → 3-pro-image
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
// gemini-2.5 → gemini-3 映射(长前缀优先)
{
"gemini-2.5-flash-thinking"
,
"gemini-3-flash"
},
// gemini-2.5-flash-thinking → gemini-3-flash
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → gemini-3-pro-image
{
"gemini-2.5-flash-lite"
,
"gemini-3-flash"
},
// gemini-2.5-flash-lite → gemini-3-flash
{
"gemini-2.5-flash"
,
"gemini-3-flash"
},
// gemini-2.5-flash → gemini-3-flash
{
"gemini-2.5-pro-preview"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro-preview → gemini-3-pro-high
{
"gemini-2.5-pro-exp"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro-exp → gemini-3-pro-high
{
"gemini-2.5-pro"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro → gemini-3-pro-high
// gemini-3 前缀映射
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
// Claude 映射
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
{
"claude-opus-4-5"
,
"claude-opus-4-5-thinking"
},
{
"claude-3-haiku"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-haiku-xxx → sonnet
{
"claude-sonnet-4"
,
"claude-sonnet-4-5"
},
{
"claude-haiku-4"
,
"claude-sonnet-4-5"
},
// → sonnet
{
"claude-opus-4"
,
"claude-opus-4-5-thinking"
},
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
...
...
@@ -1530,7 +1537,11 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
func
antigravityUseScopeRateLimit
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityScopeRateLimitEnv
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
// 默认开启按配额域限流,只有明确设置为禁用值时才关闭
if
v
==
"0"
||
v
==
"false"
||
v
==
"no"
||
v
==
"off"
{
return
false
}
return
true
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
...
...
backend/internal/service/antigravity_model_mapping_test.go
View file @
31fe0178
...
...
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected
:
"claude-sonnet-4-5"
,
},
// 3. Gemini
透传
// 3. Gemini
2.5 → 3 映射
{
name
:
"Gemini
透传
- gemini-2.5-flash"
,
name
:
"Gemini
映射
- gemini-2.5-flash
→ gemini-3-flash
"
,
requestedModel
:
"gemini-2.5-flash"
,
accountMapping
:
nil
,
expected
:
"gemini-
2.5
-flash"
,
expected
:
"gemini-
3
-flash"
,
},
{
name
:
"Gemini
透传
- gemini-2.5-pro"
,
name
:
"Gemini
映射
- gemini-2.5-pro
→ gemini-3-pro-high
"
,
requestedModel
:
"gemini-2.5-pro"
,
accountMapping
:
nil
,
expected
:
"gemini-
2.5
-pro"
,
expected
:
"gemini-
3
-pro
-high
"
,
},
{
name
:
"Gemini透传 - gemini-future-model"
,
...
...
backend/internal/service/antigravity_oauth_service.go
View file @
31fe0178
...
...
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result
.
Email
=
userInfo
.
Email
}
// 获取 project_id(部分账户类型可能没有)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
tokenResp
.
AccessToken
)
if
err
!=
nil
{
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取 project_id 失败: %v
\n
"
,
err
)
}
else
if
loadResp
!=
nil
&&
loadResp
.
CloudAICompanionProject
!=
""
{
result
.
ProjectID
=
loadResp
.
CloudAICompanionProject
// 获取 project_id(部分账户类型可能没有),失败时重试
projectID
,
loadErr
:=
s
.
loadProjectIDWithRetry
(
ctx
,
tokenResp
.
AccessToken
,
proxyURL
,
3
)
if
loadErr
!=
nil
{
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v
\n
"
,
loadErr
)
result
.
ProjectIDMissing
=
true
}
else
{
result
.
ProjectID
=
projectID
}
return
result
,
nil
...
...
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo
.
Email
=
existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client
:=
antigravity
.
NewClient
(
proxyURL
)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
tokenInfo
.
AccessToken
)
if
err
!=
nil
||
loadResp
==
nil
||
loadResp
.
CloudAICompanionProject
==
""
{
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
existingProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"
project_id
"
))
// 每次刷新都调用 LoadCodeAssist 获取 project_id
,失败时重试
existingProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
)
)
projectID
,
loadErr
:=
s
.
loadProjectIDWithRetry
(
ctx
,
tokenInfo
.
AccessToken
,
proxyURL
,
3
)
if
loadErr
!=
nil
{
// LoadCodeAssist 失败,保留原有
project_id
tokenInfo
.
ProjectID
=
existingProjectID
tokenInfo
.
ProjectIDMissing
=
true
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if
existingProjectID
==
""
{
tokenInfo
.
ProjectIDMissing
=
true
}
}
else
{
tokenInfo
.
ProjectID
=
loadResp
.
CloudAICompanionP
roject
tokenInfo
.
ProjectID
=
p
roject
ID
}
return
tokenInfo
,
nil
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func
(
s
*
AntigravityOAuthService
)
loadProjectIDWithRetry
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
,
maxRetries
int
)
(
string
,
error
)
{
var
lastErr
error
for
attempt
:=
0
;
attempt
<=
maxRetries
;
attempt
++
{
if
attempt
>
0
{
// 指数退避:1s, 2s, 4s
backoff
:=
time
.
Duration
(
1
<<
uint
(
attempt
-
1
))
*
time
.
Second
if
backoff
>
8
*
time
.
Second
{
backoff
=
8
*
time
.
Second
}
time
.
Sleep
(
backoff
)
}
client
:=
antigravity
.
NewClient
(
proxyURL
)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
accessToken
)
if
err
==
nil
&&
loadResp
!=
nil
&&
loadResp
.
CloudAICompanionProject
!=
""
{
return
loadResp
.
CloudAICompanionProject
,
nil
}
// 记录错误
if
err
!=
nil
{
lastErr
=
err
}
else
if
loadResp
==
nil
{
lastErr
=
fmt
.
Errorf
(
"LoadCodeAssist 返回空响应"
)
}
else
{
lastErr
=
fmt
.
Errorf
(
"LoadCodeAssist 返回空 project_id"
)
}
}
return
""
,
fmt
.
Errorf
(
"获取 project_id 失败 (重试 %d 次后): %w"
,
maxRetries
,
lastErr
)
}
// BuildAccountCredentials 构建账户凭证
func
(
s
*
AntigravityOAuthService
)
BuildAccountCredentials
(
tokenInfo
*
AntigravityTokenInfo
)
map
[
string
]
any
{
creds
:=
map
[
string
]
any
{
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment