"backend/vscode:/vscode.git/clone" did not exist on "a61d58716fb7a15a2bbedc5e7d8260f5b7bf0974"
Commit a161fcc8 authored by cyhhao's avatar cyhhao
Browse files

Merge branch 'main' of github.com:Wei-Shaw/sub2api

parents 65e69738 e32c5f53
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
...@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { ...@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt dst.UpdatedAt = src.UpdatedAt
} }
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
update := client.User.UpdateOneID(userID)
if encryptedSecret == nil {
update = update.ClearTotpSecretEncrypted()
} else {
update = update.SetTotpSecretEncrypted(*encryptedSecret)
}
_, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(true).
SetTotpEnabledAt(time.Now()).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(false).
ClearTotpEnabledAt().
ClearTotpSecretEncrypted().
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
...@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
} }
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query() q := client.UserSubscription.Query()
if userID != nil { if userID != nil {
...@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil { if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID)) q = q.Where(usersubscription.GroupIDEQ(*groupID))
} }
if status != "" {
// Status filtering with real-time expiration check
now := time.Now()
switch status {
case service.SubscriptionStatusActive:
// Active: status is active AND not yet expired
q = q.Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(now),
)
case service.SubscriptionStatusExpired:
// Expired: status is expired OR (status is active but already expired)
q = q.Where(
usersubscription.Or(
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
usersubscription.And(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(now),
),
),
)
case "":
// No filter
default:
// Other status (e.g., revoked)
q = q.Where(usersubscription.StatusEQ(status)) q = q.Where(usersubscription.StatusEQ(status))
} }
...@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
// Apply sorting
q = q.WithUser().WithGroup().WithAssignedByUser()
// Determine sort field
var field string
switch sortBy {
case "expires_at":
field = usersubscription.FieldExpiresAt
case "status":
field = usersubscription.FieldStatus
default:
field = usersubscription.FieldCreatedAt
}
// Determine sort order (default: desc)
if sortOrder == "asc" && sortBy != "" {
q = q.Order(dbent.Asc(field))
} else {
q = q.Order(dbent.Desc(field))
}
subs, err := q. subs, err := q.
WithUser().
WithGroup().
WithAssignedByUser().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit()).
All(ctx) All(ctx)
......
...@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { ...@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list") group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil) s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "") subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
...@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { ...@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID) s.Require().Equal(user1.ID, subs[0].UserID)
...@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { ...@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID) s.Require().Equal(g1.ID, subs[0].GroupID)
...@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { ...@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
}) })
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired) subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
......
...@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet( ...@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache, NewSchedulerCache,
NewSchedulerOutboxRepository, NewSchedulerOutboxRepository,
NewProxyLatencyCache, NewProxyLatencyCache,
NewTotpCache,
// Encryptors
NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly) // HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier, NewTurnstileVerifier,
......
...@@ -193,20 +193,20 @@ func TestAPIContracts(t *testing.T) { ...@@ -193,20 +193,20 @@ func TestAPIContracts(t *testing.T) {
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。 // 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{ deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
{ {
ID: 501, ID: 501,
UserID: 1, UserID: 1,
GroupID: 10, GroupID: 10,
StartsAt: deps.now, StartsAt: deps.now,
ExpiresAt: deps.now.Add(24 * time.Hour), ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status: service.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23, DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34, WeeklyUsageUSD: 2.34,
MonthlyUsageUSD: 3.45, MonthlyUsageUSD: 3.45,
AssignedBy: ptr(int64(999)), AssignedBy: ptr(int64(999)),
AssignedAt: deps.now, AssignedAt: deps.now,
Notes: "admin-note", Notes: "admin-note",
CreatedAt: deps.now, CreatedAt: deps.now,
UpdatedAt: deps.now, UpdatedAt: deps.now,
}, },
}) })
}, },
...@@ -222,7 +222,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -222,7 +222,7 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1, "user_id": 1,
"group_id": 10, "group_id": 10,
"starts_at": "2025-01-02T03:04:05Z", "starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2025-01-03T03:04:05Z", "expires_at": "2099-01-02T03:04:05Z",
"status": "active", "status": "active",
"daily_window_start": null, "daily_window_start": null,
"weekly_window_start": null, "weekly_window_start": null,
...@@ -452,6 +452,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -452,6 +452,9 @@ func TestAPIContracts(t *testing.T) {
"registration_enabled": true, "registration_enabled": true,
"email_verify_enabled": false, "email_verify_enabled": false,
"promo_code_enabled": true, "promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com", "smtp_host": "smtp.example.com",
"smtp_port": 587, "smtp_port": 587,
"smtp_username": "user", "smtp_username": "user",
...@@ -595,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -595,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
...@@ -754,6 +757,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -754,6 +757,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
type stubApiKeyCache struct{} type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
...@@ -1176,7 +1191,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI ...@@ -1176,7 +1191,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
......
...@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in ...@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -26,11 +26,20 @@ func RegisterAuthRoutes( ...@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{ {
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode) }), h.Auth.ValidatePromoCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ForgotPassword)
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
......
...@@ -22,6 +22,17 @@ func RegisterUserRoutes( ...@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile) user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword) user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile) user.PUT("", h.User.UpdateProfile)
// TOTP 双因素认证
totp := user.Group("/totp")
{
totp.GET("/status", h.Totp.GetStatus)
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
totp.POST("/send-code", h.Totp.SendVerifyCode)
totp.POST("/setup", h.Totp.InitiateSetup)
totp.POST("/enable", h.Totp.Enable)
totp.POST("/disable", h.Totp.Disable)
}
} }
// API Key管理 // API Key管理
......
...@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time { ...@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil return nil
} }
// GetCredentialAsInt64 解析凭证中的 int64 字段
// 用于读取 _token_version 等内部字段
func (a *Account) GetCredentialAsInt64(key string) int64 {
if a == nil || a.Credentials == nil {
return 0
}
val, ok := a.Credentials[key]
if !ok || val == nil {
return 0
}
switch v := val.(type) {
case int64:
return v
case float64:
return int64(v)
case int:
return int64(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return i
}
case string:
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
return i
}
}
return 0
}
func (a *Account) IsTempUnschedulableEnabled() bool { func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil { if a.Credentials == nil {
return false return false
......
...@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call") panic("unexpected RemoveGroupFromAllowedGroups call")
} }
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type groupRepoStub struct { type groupRepoStub struct {
affectedUserIDs []int64 affectedUserIDs []int64
deleteErr error deleteErr error
......
...@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err return nil, err
} }
// 清理 Schema
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
injectedBody = cleanedBody
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
} else {
log.Printf("[Antigravity] Failed to clean schema: %v", err)
}
// 包装请求 // 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody) wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil { if err != nil {
...@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if u := extractGeminiUsage(parsed); u != nil { if u := extractGeminiUsage(parsed); u != nil {
usage = u usage = u
} }
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
} }
if firstTokenMs == nil { if firstTokenMs == nil {
...@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont ...@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
usage = u usage = u
} }
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
// 保留最后一个有 parts 的响应 // 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 { if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed lastWithParts = parsed
...@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi ...@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
return result, existingParts, setParts return result, existingParts, setParts
} }
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
// 保持原始顺序,只合并连续的普通 text parts
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
if len(collectedParts) == 0 {
return response
}
result, _, setParts := getOrCreateGeminiParts(response)
// 合并策略:
// 1. 保持原始顺序
// 2. 连续的普通 text parts 合并为一个
// 3. thinking、functionCall、inlineData 等保持原样
var mergedParts []any
var textBuffer strings.Builder
flushTextBuffer := func() {
if textBuffer.Len() > 0 {
mergedParts = append(mergedParts, map[string]any{
"text": textBuffer.String(),
})
textBuffer.Reset()
}
}
for _, part := range collectedParts {
// 检查是否是普通 text part
if text, ok := part["text"].(string); ok {
// 检查是否有 thought 标记
if thought, _ := part["thought"].(bool); thought {
// thinking part,先刷新 text buffer,然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
} else {
// 普通 text,累积到 buffer
_, _ = textBuffer.WriteString(text)
}
} else {
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
}
}
// 刷新剩余的 text
flushTextBuffer()
setParts(mergedParts)
return result
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中 // mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any { func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
if len(imageParts) == 0 { if len(imageParts) == 0 {
...@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont ...@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int var firstTokenMs *int
var last map[string]any var last map[string]any
var lastWithParts map[string]any var lastWithParts map[string]any
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
type scanEvent struct { type scanEvent struct {
line string line string
...@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont ...@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last = parsed last = parsed
// 保留最后一个有 parts 的响应 // 保留最后一个有 parts 的响应,并收集所有 parts
if parts := extractGeminiParts(parsed); len(parts) > 0 { if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed lastWithParts = parsed
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
collectedParts = append(collectedParts, parts...)
} }
case <-intervalCh: case <-intervalCh:
...@@ -2252,6 +2343,11 @@ returnResponse: ...@@ -2252,6 +2343,11 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream") return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
} }
// 将收集的所有 parts 合并到最终响应中
if len(collectedParts) > 0 {
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
}
// 序列化为 JSON(Gemini 格式) // 序列化为 JSON(Gemini 格式)
geminiBody, err := json.Marshal(finalResponse) geminiBody, err := json.Marshal(finalResponse)
if err != nil { if err != nil {
...@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool { ...@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
modelLower == "gemini-2.5-flash-image-preview" || modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-") strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
} }
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
func cleanGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
modified := false
// 1. 清理 Tools
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
for _, t := range tools {
toolMap, ok := t.(map[string]any)
if !ok {
continue
}
// function_declarations (snake_case) or functionDeclarations (camelCase)
var funcs []any
if f, ok := toolMap["functionDeclarations"].([]any); ok {
funcs = f
} else if f, ok := toolMap["function_declarations"].([]any); ok {
funcs = f
}
if len(funcs) == 0 {
continue
}
for _, f := range funcs {
funcMap, ok := f.(map[string]any)
if !ok {
continue
}
if params, ok := funcMap["parameters"].(map[string]any); ok {
antigravity.DeepCleanUndefined(params)
cleaned := antigravity.CleanJSONSchema(params)
funcMap["parameters"] = cleaned
modified = true
}
}
}
}
if !modified {
return body, nil
}
return json.Marshal(payload)
}
...@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig ...@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email result.Email = userInfo.Email
} }
// 获取 project_id(部分账户类型可能没有) // 获取 project_id(部分账户类型可能没有),失败时重试
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken) projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if err != nil { if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err) fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" { result.ProjectIDMissing = true
result.ProjectID = loadResp.CloudAICompanionProject } else {
result.ProjectID = projectID
} }
return result, nil return result, nil
...@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou ...@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail tokenInfo.Email = existingEmail
} }
// 每次刷新都调用 LoadCodeAssist 获取 project_id // 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
client := antigravity.NewClient(proxyURL) existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken) projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失 if loadErr != nil {
existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) // LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true // 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else { } else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject tokenInfo.ProjectID = projectID
} }
return tokenInfo, nil return tokenInfo, nil
} }
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// 指数退避:1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second {
backoff = 8 * time.Second
}
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
// 记录错误
if err != nil {
lastErr = err
} else if loadResp == nil {
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
} else {
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
}
}
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
// BuildAccountCredentials 构建账户凭证 // BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{ creds := map[string]any{
......
...@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { ...@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
var handleErrorCalled bool var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{ result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]", prefix: "[test]",
ctx: context.Background(), ctx: context.Background(),
account: account, account: account,
proxyURL: "", proxyURL: "",
accessToken: "token", accessToken: "token",
action: "generateContent", action: "generateContent",
body: []byte(`{"input":"test"}`), body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude, quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream, httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
handleErrorCalled = true handleErrorCalled = true
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"log" "log"
"log/slog"
"strconv" "strconv"
"strings" "strings"
"time" "time"
...@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials") return "", errors.New("access_token not found in credentials")
} }
// 3. 存入缓存 // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil { if p.tokenCache != nil {
ttl := 30 * time.Minute latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if expiresAt != nil { if isStale && latestAccount != nil {
until := time.Until(*expiresAt) // 版本过时,使用 DB 中的最新 token
switch { slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
case until > antigravityTokenCacheSkew: accessToken = latestAccount.GetCredential("access_token")
ttl = until - antigravityTokenCacheSkew if strings.TrimSpace(accessToken) == "" {
case until > 0: return "", errors.New("access_token not found after version check")
ttl = until
default:
ttl = time.Minute
} }
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
} }
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
} }
return accessToken, nil return accessToken, nil
......
...@@ -3,6 +3,8 @@ package service ...@@ -3,6 +3,8 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"strings"
"time" "time"
) )
...@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun ...@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
} }
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo) newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for k, v := range account.Credentials { for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists { if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v newCredentials[k] = v
} }
} }
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记 // 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
if tokenInfo.ProjectIDMissing { if tokenInfo.ProjectIDMissing {
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity") if tokenInfo.ProjectID != "" {
// 有旧的 project_id,本次获取失败,保留旧值
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
} else {
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
}
} }
return newCredentials, nil return newCredentials, nil
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment