Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a161fcc8
Commit
a161fcc8
authored
Jan 26, 2026
by
cyhhao
Browse files
Merge branch 'main' of github.com:Wei-Shaw/sub2api
parents
65e69738
e32c5f53
Changes
119
Expand all
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/user_repo.go
View file @
a161fcc8
...
...
@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbuser
"github.com/Wei-Shaw/sub2api/ent/user"
...
...
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst
.
CreatedAt
=
src
.
CreatedAt
dst
.
UpdatedAt
=
src
.
UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func
(
r
*
userRepository
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
update
:=
client
.
User
.
UpdateOneID
(
userID
)
if
encryptedSecret
==
nil
{
update
=
update
.
ClearTotpSecretEncrypted
()
}
else
{
update
=
update
.
SetTotpSecretEncrypted
(
*
encryptedSecret
)
}
_
,
err
:=
update
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
true
)
.
SetTotpEnabledAt
(
time
.
Now
())
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func
(
r
*
userRepository
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
User
.
UpdateOneID
(
userID
)
.
SetTotpEnabled
(
false
)
.
ClearTotpEnabledAt
()
.
ClearTotpSecretEncrypted
()
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
return
nil
}
backend/internal/repository/user_subscription_repo.go
View file @
a161fcc8
...
...
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return
userSubscriptionEntitiesToService
(
subs
),
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
userSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
q
:=
client
.
UserSubscription
.
Query
()
if
userID
!=
nil
{
...
...
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if
groupID
!=
nil
{
q
=
q
.
Where
(
usersubscription
.
GroupIDEQ
(
*
groupID
))
}
if
status
!=
""
{
// Status filtering with real-time expiration check
now
:=
time
.
Now
()
switch
status
{
case
service
.
SubscriptionStatusActive
:
// Active: status is active AND not yet expired
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtGT
(
now
),
)
case
service
.
SubscriptionStatusExpired
:
// Expired: status is expired OR (status is active but already expired)
q
=
q
.
Where
(
usersubscription
.
Or
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusExpired
),
usersubscription
.
And
(
usersubscription
.
StatusEQ
(
service
.
SubscriptionStatusActive
),
usersubscription
.
ExpiresAtLTE
(
now
),
),
),
)
case
""
:
// No filter
default
:
// Other status (e.g., revoked)
q
=
q
.
Where
(
usersubscription
.
StatusEQ
(
status
))
}
...
...
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return
nil
,
nil
,
err
}
// Apply sorting
q
=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
// Determine sort field
var
field
string
switch
sortBy
{
case
"expires_at"
:
field
=
usersubscription
.
FieldExpiresAt
case
"status"
:
field
=
usersubscription
.
FieldStatus
default
:
field
=
usersubscription
.
FieldCreatedAt
}
// Determine sort order (default: desc)
if
sortOrder
==
"asc"
&&
sortBy
!=
""
{
q
=
q
.
Order
(
dbent
.
Asc
(
field
))
}
else
{
q
=
q
.
Order
(
dbent
.
Desc
(
field
))
}
subs
,
err
:=
q
.
WithUser
()
.
WithGroup
()
.
WithAssignedByUser
()
.
Order
(
dbent
.
Desc
(
usersubscription
.
FieldCreatedAt
))
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
All
(
ctx
)
...
...
backend/internal/repository/user_subscription_repo_integration_test.go
View file @
a161fcc8
...
...
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group
:=
s
.
mustCreateGroup
(
"g-list"
)
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
)
subs
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
...
...
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s
.
mustCreateSubscription
(
user1
.
ID
,
group
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user2
.
ID
,
group
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
&
user1
.
ID
,
nil
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
user1
.
ID
,
subs
[
0
]
.
UserID
)
...
...
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s
.
mustCreateSubscription
(
user
.
ID
,
g1
.
ID
,
nil
)
s
.
mustCreateSubscription
(
user
.
ID
,
g2
.
ID
,
nil
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
&
g1
.
ID
,
""
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
g1
.
ID
,
subs
[
0
]
.
GroupID
)
...
...
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
-
24
*
time
.
Hour
))
})
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
)
subs
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
nil
,
nil
,
service
.
SubscriptionStatusExpired
,
""
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
subs
,
1
)
s
.
Require
()
.
Equal
(
service
.
SubscriptionStatusExpired
,
subs
[
0
]
.
Status
)
...
...
backend/internal/repository/wire.go
View file @
a161fcc8
...
...
@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache
,
NewSchedulerOutboxRepository
,
NewProxyLatencyCache
,
NewTotpCache
,
// Encryptors
NewAESEncryptor
,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
...
...
backend/internal/server/api_contract_test.go
View file @
a161fcc8
...
...
@@ -193,20 +193,20 @@ func TestAPIContracts(t *testing.T) {
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps
.
userSubRepo
.
SetByUserID
(
1
,
[]
service
.
UserSubscription
{
{
ID
:
501
,
UserID
:
1
,
GroupID
:
10
,
StartsAt
:
deps
.
now
,
ExpiresAt
:
deps
.
now
.
Add
(
24
*
time
.
Hour
),
Status
:
service
.
SubscriptionStatusActive
,
ID
:
501
,
UserID
:
1
,
GroupID
:
10
,
StartsAt
:
deps
.
now
,
ExpiresAt
:
time
.
Date
(
2099
,
1
,
2
,
3
,
4
,
5
,
0
,
time
.
UTC
),
// 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status
:
service
.
SubscriptionStatusActive
,
DailyUsageUSD
:
1.23
,
WeeklyUsageUSD
:
2.34
,
MonthlyUsageUSD
:
3.45
,
AssignedBy
:
ptr
(
int64
(
999
)),
AssignedAt
:
deps
.
now
,
Notes
:
"admin-note"
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
AssignedBy
:
ptr
(
int64
(
999
)),
AssignedAt
:
deps
.
now
,
Notes
:
"admin-note"
,
CreatedAt
:
deps
.
now
,
UpdatedAt
:
deps
.
now
,
},
})
},
...
...
@@ -222,7 +222,7 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "20
25
-01-0
3
T03:04:05Z",
"expires_at": "20
99
-01-0
2
T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
...
...
@@ -452,6 +452,9 @@ func TestAPIContracts(t *testing.T) {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
...
...
@@ -595,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
...
...
@@ -754,6 +757,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
type
stubApiKeyCache
struct
{}
func
(
stubApiKeyCache
)
GetCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
...
...
@@ -1176,7 +1191,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func
(
stubUserSubscriptionRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubUserSubscriptionRepo
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
a161fcc8
...
...
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/routes/auth.go
View file @
a161fcc8
...
...
@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/validate-promo-code"
,
rateLimiter
.
LimitWithOptions
(
"validate-promo"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ValidatePromoCode
)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth
.
POST
(
"/forgot-password"
,
rateLimiter
.
LimitWithOptions
(
"forgot-password"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ForgotPassword
)
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/reset-password"
,
rateLimiter
.
LimitWithOptions
(
"reset-password"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ResetPassword
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
...
...
backend/internal/server/routes/user.go
View file @
a161fcc8
...
...
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user
.
GET
(
"/profile"
,
h
.
User
.
GetProfile
)
user
.
PUT
(
"/password"
,
h
.
User
.
ChangePassword
)
user
.
PUT
(
""
,
h
.
User
.
UpdateProfile
)
// TOTP 双因素认证
totp
:=
user
.
Group
(
"/totp"
)
{
totp
.
GET
(
"/status"
,
h
.
Totp
.
GetStatus
)
totp
.
GET
(
"/verification-method"
,
h
.
Totp
.
GetVerificationMethod
)
totp
.
POST
(
"/send-code"
,
h
.
Totp
.
SendVerifyCode
)
totp
.
POST
(
"/setup"
,
h
.
Totp
.
InitiateSetup
)
totp
.
POST
(
"/enable"
,
h
.
Totp
.
Enable
)
totp
.
POST
(
"/disable"
,
h
.
Totp
.
Disable
)
}
}
// API Key管理
...
...
backend/internal/service/account.go
View file @
a161fcc8
...
...
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return
nil
}
// GetCredentialAsInt64 解析凭证中的 int64 字段
// 用于读取 _token_version 等内部字段
func
(
a
*
Account
)
GetCredentialAsInt64
(
key
string
)
int64
{
if
a
==
nil
||
a
.
Credentials
==
nil
{
return
0
}
val
,
ok
:=
a
.
Credentials
[
key
]
if
!
ok
||
val
==
nil
{
return
0
}
switch
v
:=
val
.
(
type
)
{
case
int64
:
return
v
case
float64
:
return
int64
(
v
)
case
int
:
return
int64
(
v
)
case
json
.
Number
:
if
i
,
err
:=
v
.
Int64
();
err
==
nil
{
return
i
}
case
string
:
if
i
,
err
:=
strconv
.
ParseInt
(
strings
.
TrimSpace
(
v
),
10
,
64
);
err
==
nil
{
return
i
}
}
return
0
}
func
(
a
*
Account
)
IsTempUnschedulableEnabled
()
bool
{
if
a
.
Credentials
==
nil
{
return
false
...
...
backend/internal/service/admin_service_delete_test.go
View file @
a161fcc8
...
...
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic
(
"unexpected RemoveGroupFromAllowedGroups call"
)
}
func
(
s
*
userRepoStub
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
panic
(
"unexpected UpdateTotpSecret call"
)
}
func
(
s
*
userRepoStub
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected EnableTotp call"
)
}
func
(
s
*
userRepoStub
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected DisableTotp call"
)
}
type
groupRepoStub
struct
{
affectedUserIDs
[]
int64
deleteErr
error
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
a161fcc8
...
...
@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return
nil
,
err
}
// 清理 Schema
if
cleanedBody
,
err
:=
cleanGeminiRequest
(
injectedBody
);
err
==
nil
{
injectedBody
=
cleanedBody
log
.
Printf
(
"[Antigravity] Cleaned request schema in forwarded request for account %s"
,
account
.
Name
)
}
else
{
log
.
Printf
(
"[Antigravity] Failed to clean schema: %v"
,
err
)
}
// 包装请求
wrappedBody
,
err
:=
s
.
wrapV1InternalRequest
(
projectID
,
mappedModel
,
injectedBody
)
if
err
!=
nil
{
...
...
@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if
u
:=
extractGeminiUsage
(
parsed
);
u
!=
nil
{
usage
=
u
}
// Check for MALFORMED_FUNCTION_CALL
if
candidates
,
ok
:=
parsed
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
fr
,
ok
:=
cand
[
"finishReason"
]
.
(
string
);
ok
&&
fr
==
"MALFORMED_FUNCTION_CALL"
{
log
.
Printf
(
"[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream"
)
if
content
,
ok
:=
cand
[
"content"
];
ok
{
if
b
,
err
:=
json
.
Marshal
(
content
);
err
==
nil
{
log
.
Printf
(
"[Antigravity] Malformed content: %s"
,
string
(
b
))
}
}
}
}
}
}
if
firstTokenMs
==
nil
{
...
...
@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
usage
=
u
}
// Check for MALFORMED_FUNCTION_CALL
if
candidates
,
ok
:=
parsed
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
fr
,
ok
:=
cand
[
"finishReason"
]
.
(
string
);
ok
&&
fr
==
"MALFORMED_FUNCTION_CALL"
{
log
.
Printf
(
"[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect"
)
if
content
,
ok
:=
cand
[
"content"
];
ok
{
if
b
,
err
:=
json
.
Marshal
(
content
);
err
==
nil
{
log
.
Printf
(
"[Antigravity] Malformed content: %s"
,
string
(
b
))
}
}
}
}
}
// 保留最后一个有 parts 的响应
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
lastWithParts
=
parsed
...
...
@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
return
result
,
existingParts
,
setParts
}
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
// 保持原始顺序,只合并连续的普通 text parts
func
mergeCollectedPartsToResponse
(
response
map
[
string
]
any
,
collectedParts
[]
map
[
string
]
any
)
map
[
string
]
any
{
if
len
(
collectedParts
)
==
0
{
return
response
}
result
,
_
,
setParts
:=
getOrCreateGeminiParts
(
response
)
// 合并策略:
// 1. 保持原始顺序
// 2. 连续的普通 text parts 合并为一个
// 3. thinking、functionCall、inlineData 等保持原样
var
mergedParts
[]
any
var
textBuffer
strings
.
Builder
flushTextBuffer
:=
func
()
{
if
textBuffer
.
Len
()
>
0
{
mergedParts
=
append
(
mergedParts
,
map
[
string
]
any
{
"text"
:
textBuffer
.
String
(),
})
textBuffer
.
Reset
()
}
}
for
_
,
part
:=
range
collectedParts
{
// 检查是否是普通 text part
if
text
,
ok
:=
part
[
"text"
]
.
(
string
);
ok
{
// 检查是否有 thought 标记
if
thought
,
_
:=
part
[
"thought"
]
.
(
bool
);
thought
{
// thinking part,先刷新 text buffer,然后保留原样
flushTextBuffer
()
mergedParts
=
append
(
mergedParts
,
part
)
}
else
{
// 普通 text,累积到 buffer
_
,
_
=
textBuffer
.
WriteString
(
text
)
}
}
else
{
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
flushTextBuffer
()
mergedParts
=
append
(
mergedParts
,
part
)
}
}
// 刷新剩余的 text
flushTextBuffer
()
setParts
(
mergedParts
)
return
result
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func
mergeImagePartsToResponse
(
response
map
[
string
]
any
,
imageParts
[]
map
[
string
]
any
)
map
[
string
]
any
{
if
len
(
imageParts
)
==
0
{
...
...
@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var
firstTokenMs
*
int
var
last
map
[
string
]
any
var
lastWithParts
map
[
string
]
any
var
collectedParts
[]
map
[
string
]
any
// 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
type
scanEvent
struct
{
line
string
...
...
@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last
=
parsed
// 保留最后一个有 parts 的响应
// 保留最后一个有 parts 的响应
,并收集所有 parts
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
lastWithParts
=
parsed
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
collectedParts
=
append
(
collectedParts
,
parts
...
)
}
case
<-
intervalCh
:
...
...
@@ -2252,6 +2343,11 @@ returnResponse:
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Empty response from upstream"
)
}
// 将收集的所有 parts 合并到最终响应中
if
len
(
collectedParts
)
>
0
{
finalResponse
=
mergeCollectedPartsToResponse
(
finalResponse
,
collectedParts
)
}
// 序列化为 JSON(Gemini 格式)
geminiBody
,
err
:=
json
.
Marshal
(
finalResponse
)
if
err
!=
nil
{
...
...
@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
modelLower
==
"gemini-2.5-flash-image-preview"
||
strings
.
HasPrefix
(
modelLower
,
"gemini-2.5-flash-image-"
)
}
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
func
cleanGeminiRequest
(
body
[]
byte
)
([]
byte
,
error
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
nil
,
err
}
modified
:=
false
// 1. 清理 Tools
if
tools
,
ok
:=
payload
[
"tools"
]
.
([]
any
);
ok
&&
len
(
tools
)
>
0
{
for
_
,
t
:=
range
tools
{
toolMap
,
ok
:=
t
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
// function_declarations (snake_case) or functionDeclarations (camelCase)
var
funcs
[]
any
if
f
,
ok
:=
toolMap
[
"functionDeclarations"
]
.
([]
any
);
ok
{
funcs
=
f
}
else
if
f
,
ok
:=
toolMap
[
"function_declarations"
]
.
([]
any
);
ok
{
funcs
=
f
}
if
len
(
funcs
)
==
0
{
continue
}
for
_
,
f
:=
range
funcs
{
funcMap
,
ok
:=
f
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
params
,
ok
:=
funcMap
[
"parameters"
]
.
(
map
[
string
]
any
);
ok
{
antigravity
.
DeepCleanUndefined
(
params
)
cleaned
:=
antigravity
.
CleanJSONSchema
(
params
)
funcMap
[
"parameters"
]
=
cleaned
modified
=
true
}
}
}
}
if
!
modified
{
return
body
,
nil
}
return
json
.
Marshal
(
payload
)
}
backend/internal/service/antigravity_oauth_service.go
View file @
a161fcc8
...
...
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result
.
Email
=
userInfo
.
Email
}
// 获取 project_id(部分账户类型可能没有)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
tokenResp
.
AccessToken
)
if
err
!=
nil
{
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取 project_id 失败: %v
\n
"
,
err
)
}
else
if
loadResp
!=
nil
&&
loadResp
.
CloudAICompanionProject
!=
""
{
result
.
ProjectID
=
loadResp
.
CloudAICompanionProject
// 获取 project_id(部分账户类型可能没有),失败时重试
projectID
,
loadErr
:=
s
.
loadProjectIDWithRetry
(
ctx
,
tokenResp
.
AccessToken
,
proxyURL
,
3
)
if
loadErr
!=
nil
{
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v
\n
"
,
loadErr
)
result
.
ProjectIDMissing
=
true
}
else
{
result
.
ProjectID
=
projectID
}
return
result
,
nil
...
...
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo
.
Email
=
existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client
:=
antigravity
.
NewClient
(
proxyURL
)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
tokenInfo
.
AccessToken
)
if
err
!=
nil
||
loadResp
==
nil
||
loadResp
.
CloudAICompanionProject
==
""
{
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
existingProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"
project_id
"
))
// 每次刷新都调用 LoadCodeAssist 获取 project_id
,失败时重试
existingProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
)
)
projectID
,
loadErr
:=
s
.
loadProjectIDWithRetry
(
ctx
,
tokenInfo
.
AccessToken
,
proxyURL
,
3
)
if
loadErr
!=
nil
{
// LoadCodeAssist 失败,保留原有
project_id
tokenInfo
.
ProjectID
=
existingProjectID
tokenInfo
.
ProjectIDMissing
=
true
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if
existingProjectID
==
""
{
tokenInfo
.
ProjectIDMissing
=
true
}
}
else
{
tokenInfo
.
ProjectID
=
loadResp
.
CloudAICompanionP
roject
tokenInfo
.
ProjectID
=
p
roject
ID
}
return
tokenInfo
,
nil
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func
(
s
*
AntigravityOAuthService
)
loadProjectIDWithRetry
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
,
maxRetries
int
)
(
string
,
error
)
{
var
lastErr
error
for
attempt
:=
0
;
attempt
<=
maxRetries
;
attempt
++
{
if
attempt
>
0
{
// 指数退避:1s, 2s, 4s
backoff
:=
time
.
Duration
(
1
<<
uint
(
attempt
-
1
))
*
time
.
Second
if
backoff
>
8
*
time
.
Second
{
backoff
=
8
*
time
.
Second
}
time
.
Sleep
(
backoff
)
}
client
:=
antigravity
.
NewClient
(
proxyURL
)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
accessToken
)
if
err
==
nil
&&
loadResp
!=
nil
&&
loadResp
.
CloudAICompanionProject
!=
""
{
return
loadResp
.
CloudAICompanionProject
,
nil
}
// 记录错误
if
err
!=
nil
{
lastErr
=
err
}
else
if
loadResp
==
nil
{
lastErr
=
fmt
.
Errorf
(
"LoadCodeAssist 返回空响应"
)
}
else
{
lastErr
=
fmt
.
Errorf
(
"LoadCodeAssist 返回空 project_id"
)
}
}
return
""
,
fmt
.
Errorf
(
"获取 project_id 失败 (重试 %d 次后): %w"
,
maxRetries
,
lastErr
)
}
// BuildAccountCredentials 构建账户凭证
func
(
s
*
AntigravityOAuthService
)
BuildAccountCredentials
(
tokenInfo
*
AntigravityTokenInfo
)
map
[
string
]
any
{
creds
:=
map
[
string
]
any
{
...
...
backend/internal/service/antigravity_rate_limit_test.go
View file @
a161fcc8
...
...
@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
var
handleErrorCalled
bool
result
,
err
:=
antigravityRetryLoop
(
antigravityRetryLoopParams
{
prefix
:
"[test]"
,
ctx
:
context
.
Background
(),
account
:
account
,
proxyURL
:
""
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
quotaScope
:
AntigravityQuotaScopeClaude
,
prefix
:
"[test]"
,
ctx
:
context
.
Background
(),
account
:
account
,
proxyURL
:
""
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
quotaScope
:
AntigravityQuotaScopeClaude
,
httpUpstream
:
upstream
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
handleErrorCalled
=
true
...
...
backend/internal/service/antigravity_token_provider.go
View file @
a161fcc8
...
...
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
...
...
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
// 3. 存入缓存
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
antigravityTokenCacheSkew
:
ttl
=
until
-
antigravityTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"antigravity_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
antigravityTokenCacheSkew
:
ttl
=
until
-
antigravityTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
return
accessToken
,
nil
...
...
backend/internal/service/antigravity_token_refresher.go
View file @
a161fcc8
...
...
@@ -3,6 +3,8 @@ package service
import
(
"context"
"fmt"
"log"
"strings"
"time"
)
...
...
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
if
newProjectID
,
_
:=
newCredentials
[
"project_id"
]
.
(
string
);
newProjectID
==
""
{
if
oldProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
));
oldProjectID
!=
""
{
newCredentials
[
"project_id"
]
=
oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
if
tokenInfo
.
ProjectIDMissing
{
return
newCredentials
,
fmt
.
Errorf
(
"missing_project_id: 账户缺少project id,可能无法使用Antigravity"
)
if
tokenInfo
.
ProjectID
!=
""
{
// 有旧的 project_id,本次获取失败,保留旧值
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id"
,
account
.
ID
)
}
else
{
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试"
,
account
.
ID
)
}
}
return
newCredentials
,
nil
...
...
backend/internal/service/auth_service.go
View file @
a161fcc8
This diff is collapsed.
Click to expand it.
backend/internal/service/auth_service_register_test.go
View file @
a161fcc8
This diff is collapsed.
Click to expand it.
backend/internal/service/claude_token_provider.go
View file @
a161fcc8
This diff is collapsed.
Click to expand it.
backend/internal/service/domain_constants.go
View file @
a161fcc8
This diff is collapsed.
Click to expand it.
backend/internal/service/email_queue_service.go
View file @
a161fcc8
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment