"frontend/src/api/vscode:/vscode.git/clone" did not exist on "ff3f514f6b75dfe3e033dfac3faab65080a2eea7"
Commit 65e69738 authored by cyhhao's avatar cyhhao
Browse files

Merge branch 'main' of github.com:Wei-Shaw/sub2api

parents c8e2f614 39fad63c
...@@ -84,7 +84,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -84,7 +84,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
} }
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db) schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
...@@ -129,7 +130,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -129,7 +130,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService) promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db) opsRepository := repository.NewOpsRepository(db)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
......
...@@ -47,6 +47,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -47,6 +47,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
SMTPHost: settings.SMTPHost, SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort, SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername, SMTPUsername: settings.SMTPUsername,
...@@ -90,6 +91,7 @@ type UpdateSettingsRequest struct { ...@@ -90,6 +91,7 @@ type UpdateSettingsRequest struct {
// 注册设置 // 注册设置
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
// 邮件服务设置 // 邮件服务设置
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
...@@ -240,6 +242,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -240,6 +242,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
settings := &service.SystemSettings{ settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled, RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
SMTPHost: req.SMTPHost, SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort, SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername, SMTPUsername: req.SMTPUsername,
...@@ -314,6 +317,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -314,6 +317,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled, RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
SMTPHost: updatedSettings.SMTPHost, SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort, SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername, SMTPUsername: updatedSettings.SMTPUsername,
......
...@@ -195,6 +195,15 @@ type ValidatePromoCodeResponse struct { ...@@ -195,6 +195,15 @@ type ValidatePromoCodeResponse struct {
// ValidatePromoCode 验证优惠码(公开接口,注册前调用) // ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code // POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) { func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
// 检查优惠码功能是否启用
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_DISABLED",
})
return
}
var req ValidatePromoCodeRequest var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
......
...@@ -4,6 +4,7 @@ package dto ...@@ -4,6 +4,7 @@ package dto
type SystemSettings struct { type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
...@@ -55,6 +56,7 @@ type SystemSettings struct { ...@@ -55,6 +56,7 @@ type SystemSettings struct {
type PublicSettings struct { type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
......
...@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{ response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName, SiteName: settings.SiteName,
......
...@@ -305,3 +305,139 @@ func mustParseURL(rawURL string) *url.URL { ...@@ -305,3 +305,139 @@ func mustParseURL(rawURL string) *url.URL {
} }
return u return u
} }
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
return nil
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}
...@@ -39,9 +39,15 @@ import ( ...@@ -39,9 +39,15 @@ import (
// 设计说明: // 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作 // - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作 // - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
type accountRepository struct { type accountRepository struct {
client *dbent.Client // Ent ORM 客户端 client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口 sql sqlExecutor // 原生 SQL 执行接口
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
// 确保粘性会话能及时感知账号不可用状态。
// Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache service.SchedulerCache
} }
type tempUnschedSnapshot struct { type tempUnschedSnapshot struct {
...@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct { ...@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
// NewAccountRepository 创建账户仓储实例。 // NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository { func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB) return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
} }
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。 // newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。 // 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository { func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
return &accountRepository{client: client, sql: sqlq} return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
} }
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
...@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
} }
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
}
return nil return nil
} }
...@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str ...@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
} }
r.syncSchedulerAccountSnapshot(ctx, id)
return nil return nil
} }
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
//
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
// when account status changes. Called when account is set to error, disabled,
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
// logic can promptly detect the latest account state and avoid using unavailable accounts.
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
if r == nil || r.schedulerCache == nil || accountID <= 0 {
return
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error { func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update(). _, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)). Where(dbaccount.IDEQ(id)).
...@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, ...@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
} }
r.syncSchedulerAccountSnapshot(ctx, id)
return nil return nil
} }
...@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu ...@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
} }
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil return nil
} }
...@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err) log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
} }
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
shouldSync = true
}
if updates.Schedulable != nil && !*updates.Schedulable {
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
}
} }
return rows, nil return rows, nil
} }
......
...@@ -21,11 +21,56 @@ type AccountRepoSuite struct { ...@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
repo *accountRepository repo *accountRepository
} }
type schedulerCacheRecorder struct {
setAccounts []*service.Account
}
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
return nil, false, nil
}
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
return nil
}
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
s.setAccounts = append(s.setAccounts, account)
return nil
}
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
return nil
}
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
return 0, nil
}
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
return nil
}
func (s *AccountRepoSuite) SetupTest() { func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
tx := testEntTx(s.T()) tx := testEntTx(s.T())
s.client = tx.Client() s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx) s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
} }
func TestAccountRepoSuite(t *testing.T) { func TestAccountRepoSuite(t *testing.T) {
...@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() { ...@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name) s.Require().Equal("updated", got.Name)
} }
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
}
func (s *AccountRepoSuite) TestDelete() { func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
...@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// 每个 case 重新获取隔离资源 // 每个 case 重新获取隔离资源
tx := testEntTx(s.T()) tx := testEntTx(s.T())
client := tx.Client() client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx) repo := newAccountRepositoryWithSQL(client, tx, nil)
ctx := context.Background() ctx := context.Background()
tt.setup(client) tt.setup(client)
...@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { ...@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
func (s *AccountRepoSuite) TestSetSchedulable() { func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false)) s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID) got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().False(got.Schedulable) s.Require().False(got.Schedulable)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
}
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
disabled := service.StatusDisabled
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
Status: &disabled,
})
s.Require().NoError(err)
s.Require().Equal(int64(2), rows)
s.Require().Len(cacheRecorder.setAccounts, 2)
ids := map[int64]struct{}{}
for _, acc := range cacheRecorder.setAccounts {
ids[acc.ID] = struct{}{}
}
s.Require().Contains(ids, account1.ID)
s.Require().Contains(ids, account2.ID)
} }
// --- SetOverloaded / SetRateLimited / ClearRateLimit --- // --- SetOverloaded / SetRateLimited / ClearRateLimit ---
......
...@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses ...@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
key := buildSessionKey(groupID, sessionHash) key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err() return c.rdb.Expire(ctx, key, ttl).Err()
} }
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
...@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() { ...@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error") require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
} }
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
sessionID := "openai:s4"
accountID := int64(102)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted" sessionID := "corrupted"
groupID := int64(1) groupID := int64(1)
......
...@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() { ...@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
tx := testEntTx(s.T()) tx := testEntTx(s.T())
s.client = tx.Client() s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx) s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
} }
func TestGatewayRoutingSuite(t *testing.T) { func TestGatewayRoutingSuite(t *testing.T) {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/url" "net/url"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...@@ -21,7 +22,7 @@ type openaiOAuthService struct { ...@@ -21,7 +22,7 @@ type openaiOAuthService struct {
} }
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL) client := createOpenAIReqClient(s.tokenURL, proxyURL)
if redirectURI == "" { if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI redirectURI = openai.DefaultRedirectURI
...@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie ...@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
} }
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL) client := createOpenAIReqClient(s.tokenURL, proxyURL)
formData := url.Values{} formData := url.Values{}
formData.Set("grant_type", "refresh_token") formData.Set("grant_type", "refresh_token")
...@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro ...@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil return &tokenResp, nil
} }
func createOpenAIReqClient(proxyURL string) *req.Client { func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
forceHTTP2 := false
if parsedURL, err := url.Parse(tokenURL); err == nil {
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
}
return getSharedReqClient(reqClientOptions{ return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: 60 * time.Second, Timeout: 120 * time.Second,
ForceHTTP2: forceHTTP2,
}) })
} }
...@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() { ...@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401") require.ErrorContains(s.T(), err, "status 401")
} }
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
client := NewOpenAIOAuthClient()
svc, ok := client.(*openaiOAuthService)
require.True(t, ok)
require.Equal(t, openai.TokenURL, svc.tokenURL)
}
func TestOpenAIOAuthServiceSuite(t *testing.T) { func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite)) suite.Run(t, new(OpenAIOAuthServiceSuite))
} }
...@@ -14,6 +14,7 @@ type reqClientOptions struct { ...@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL string // 代理 URL(支持 http/https/socks5) ProxyURL string // 代理 URL(支持 http/https/socks5)
Timeout time.Duration // 请求超时时间 Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹 Impersonate bool // 是否模拟 Chrome 浏览器指纹
ForceHTTP2 bool // 是否强制使用 HTTP/2
} }
// sharedReqClients 存储按配置参数缓存的 req 客户端实例 // sharedReqClients 存储按配置参数缓存的 req 客户端实例
...@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { ...@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
} }
client := req.C().SetTimeout(opts.Timeout) client := req.C().SetTimeout(opts.Timeout)
if opts.ForceHTTP2 {
client = client.EnableForceHTTP2()
}
if opts.Impersonate { if opts.Impersonate {
client = client.ImpersonateChrome() client = client.ImpersonateChrome()
} }
...@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { ...@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
} }
func buildReqClientKey(opts reqClientOptions) string { func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t", return fmt.Sprintf("%s|%s|%t|%t",
strings.TrimSpace(opts.ProxyURL), strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(), opts.Timeout.String(),
opts.Impersonate, opts.Impersonate,
opts.ForceHTTP2,
) )
} }
package repository
import (
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func forceHTTPVersion(t *testing.T, client *req.Client) string {
t.Helper()
transport := client.GetTransport()
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
require.True(t, field.IsValid(), "forceHttpVersion field not found")
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
}
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
sharedReqClients = sync.Map{}
base := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
}
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
require.Same(t, first, second)
}
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 3 * time.Second,
}
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
require.True(t, ok)
require.IsType(t, "invalid", loaded)
}
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, "2", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
...@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) { ...@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox") _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB) accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
outboxRepo := NewSchedulerOutboxRepository(integrationDB) outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb) cache := NewSchedulerCache(rdb)
......
...@@ -412,6 +412,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -412,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps.settingRepo.SetAll(map[string]string{ deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false", service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com", service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587", service.SettingKeySMTPPort: "587",
...@@ -450,6 +451,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -450,6 +451,7 @@ func TestAPIContracts(t *testing.T) {
"data": { "data": {
"registration_enabled": true, "registration_enabled": true,
"email_verify_enabled": false, "email_verify_enabled": false,
"promo_code_enabled": true,
"smtp_host": "smtp.example.com", "smtp_host": "smtp.example.com",
"smtp_port": 587, "smtp_port": 587,
"smtp_username": "user", "smtp_username": "user",
......
...@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
// 应用优惠码(如果提供) // 应用优惠码(如果提供且功能已启用
if promoCode != "" && s.promoService != nil { if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志 // 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err) log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
......
...@@ -71,6 +71,7 @@ const ( ...@@ -71,6 +71,7 @@ const (
// 注册设置 // 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
// 邮件服务设置 // 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment