Commit 106b20cd authored by Elysia's avatar Elysia
Browse files

fix claudecode review bug

parent c069b3b1
...@@ -33,7 +33,7 @@ func main() { ...@@ -33,7 +33,7 @@ func main() {
}() }()
userRepo := repository.NewUserRepository(client, sqlDB) userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
......
...@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
......
...@@ -264,11 +264,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -264,11 +264,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil { if err != nil {
statusCode := http.StatusBadRequest response.ErrorFrom(c, err)
c.JSON(statusCode, gin.H{
"error": infraerrors.Reason(err),
"message": infraerrors.Message(err),
})
return return
} }
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
...@@ -59,6 +60,7 @@ type JWTClaims struct { ...@@ -59,6 +60,7 @@ type JWTClaims struct {
// AuthService 认证服务 // AuthService 认证服务
type AuthService struct { type AuthService struct {
entClient *dbent.Client
userRepo UserRepository userRepo UserRepository
redeemRepo RedeemCodeRepository redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache refreshTokenCache RefreshTokenCache
...@@ -77,6 +79,7 @@ type DefaultSubscriptionAssigner interface { ...@@ -77,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService( func NewAuthService(
entClient *dbent.Client,
userRepo UserRepository, userRepo UserRepository,
redeemRepo RedeemCodeRepository, redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache, refreshTokenCache RefreshTokenCache,
...@@ -89,6 +92,7 @@ func NewAuthService( ...@@ -89,6 +92,7 @@ func NewAuthService(
defaultSubAssigner DefaultSubscriptionAssigner, defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
entClient: entClient,
userRepo: userRepo, userRepo: userRepo,
redeemRepo: redeemRepo, redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache, refreshTokenCache: refreshTokenCache,
...@@ -597,7 +601,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -597,7 +601,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Status: StatusActive, Status: StatusActive,
} }
if err := s.userRepo.Create(ctx, newUser); err != nil { if s.entClient != nil && invitationRedeemCode != nil {
tx, err := s.entClient.Tx(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
return nil, nil, ErrServiceUnavailable
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.userRepo.Create(txCtx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) { if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email) user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -609,12 +622,31 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -609,12 +622,31 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable return nil, nil, ErrServiceUnavailable
} }
} else { } else {
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if err := tx.Commit(); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
return nil, nil, ErrServiceUnavailable
}
user = newUser user = newUser
s.assignDefaultSubscriptions(ctx, user.ID) s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for oauth user %d: %v", user.ID, err)
} }
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
} }
} }
} else { } else {
...@@ -644,9 +676,13 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -644,9 +676,13 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. // pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct { type pendingOAuthClaims struct {
Email string `json:"email"` Email string `json:"email"`
Username string `json:"username"` Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims jwt.RegisteredClaims
} }
...@@ -657,6 +693,7 @@ func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, e ...@@ -657,6 +693,7 @@ func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, e
claims := &pendingOAuthClaims{ claims := &pendingOAuthClaims{
Email: email, Email: email,
Username: username, Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now),
...@@ -687,6 +724,9 @@ func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username ...@@ -687,6 +724,9 @@ func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username
if !ok || !token.Valid { if !ok || !token.Valid {
return "", "", ErrInvalidToken return "", "", ErrInvalidToken
} }
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil return claims.Email, claims.Username, nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment