Commit e9de839d authored by IanShaw027's avatar IanShaw027
Browse files

feat: rebuild auth identity foundation flow

parent fbd0a2e3
//go:build unit
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func newAuthServiceForPendingOAuthTest() *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-pending-oauth",
ExpireHour: 1,
},
}
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
require.NotEmpty(t, token)
email, username, err := svc.VerifyPendingOAuthToken(token)
require.NoError(t, err)
require.Equal(t, "user@example.com", email)
require.Equal(t, "alice", username)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
accessToken, err := svc.GenerateToken(&User{
ID: 1,
Email: "user@example.com",
Role: RoleUser,
})
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "some_other_purpose",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "", // 旧 token 无此字段,反序列化后为零值
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
past := time.Now().Add(-1 * time.Hour)
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(past),
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
other := NewAuthService(nil, nil, nil, nil, &config.Config{
JWT: config.JWTConfig{Secret: "other-secret"},
}, nil, nil, nil, nil, nil, nil)
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
svc := newAuthServiceForPendingOAuthTest()
_, _, err = svc.VerifyPendingOAuthToken(token)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
giant := make([]byte, maxTokenLength+1)
for i := range giant {
giant[i] = 'a'
}
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
require.ErrorIs(t, err, ErrInvalidToken)
}
...@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" ...@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。 // OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
// Setting keys // Setting keys
const ( const (
// 注册设置 // 注册设置
...@@ -153,6 +156,29 @@ const ( ...@@ -153,6 +156,29 @@ const (
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key // 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
......
...@@ -13,14 +13,30 @@ import ( ...@@ -13,14 +13,30 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/sync/singleflight"
) )
const ( const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id" openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash" openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance" openAIAccountScheduleLayerLoadBalance = "load_balance"
openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
const (
openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
) )
type cachedOpenAIAdvancedSchedulerSetting struct {
enabled bool
expiresAt int64
}
var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
var openAIAdvancedSchedulerSettingSF singleflight.Group
type OpenAIAccountScheduleRequest struct { type OpenAIAccountScheduleRequest struct {
GroupID *int64 GroupID *int64
SessionHash string SessionHash string
...@@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler ...@@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot return snapshot
} }
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
return nil
}
return s.rateLimitService.settingService.settingRepo
}
func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled
}
}
result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled, nil
}
}
enabled := false
if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
defer cancel()
value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
if err == nil {
enabled = strings.EqualFold(strings.TrimSpace(value), "true")
}
}
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: enabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
return enabled, nil
})
enabled, _ := result.(bool)
return enabled
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil { if s == nil {
return nil return nil
} }
if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
return nil
}
s.openaiSchedulerOnce.Do(func() { s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil { if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats() s.openaiAccountStats = newOpenAIAccountRuntimeStats()
...@@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule ...@@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler return s.openaiScheduler
} }
func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
openAIAdvancedSchedulerSettingCache = atomic.Value{}
openAIAdvancedSchedulerSettingSF = singleflight.Group{}
}
func (s *OpenAIGatewayService) SelectAccountWithScheduler( func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context, ctx context.Context,
groupID *int64, groupID *int64,
...@@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( ...@@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport OpenAIUpstreamTransport, requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{} decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler() scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil { if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance decision.Layer = openAIAccountScheduleLayerLoadBalance
...@@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( ...@@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
} }
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler() scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil { if scheduler == nil {
return return
} }
...@@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64 ...@@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
} }
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
scheduler := s.getOpenAIAccountScheduler() scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil { if scheduler == nil {
return return
} }
...@@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { ...@@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
} }
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
scheduler := s.getOpenAIAccountScheduler() scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil { if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{} return OpenAIAccountSchedulerMetricsSnapshot{}
} }
......
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"math" "math"
"sync" "sync"
...@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct { ...@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account accountsByID map[int64]*Account
} }
type schedulerTestOpenAIAccountRepo struct {
AccountRepository
accounts []Account
}
func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, errors.New("account not found")
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
type schedulerTestConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
loadMap map[int64]*AccountLoadInfo
acquireResults map[int64]bool
waitCounts map[int64]int
skipDefaultLoad bool
}
func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok {
return result, nil
}
}
return true, nil
}
func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if c.loadBatchErr != nil {
return nil, c.loadBatchErr
}
out := make(map[int64]*AccountLoadInfo, len(accounts))
if c.skipDefaultLoad && c.loadMap != nil {
for _, acc := range accounts {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
}
}
return out, nil
}
for _, acc := range accounts {
if c.loadMap != nil {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
continue
}
}
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
}
func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
return count, nil
}
}
return 0, nil
}
type schedulerTestGatewayCache struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := c.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if c.sessionBindings == nil {
c.sessionBindings = make(map[string]int64)
}
c.sessionBindings[sessionHash] = accountID
return nil
}
func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}
func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if c.sessionBindings == nil {
return nil
}
if c.deletedSessions == nil {
c.deletedSessions = make(map[string]int)
}
c.deletedSessions[sessionHash]++
delete(c.sessionBindings, sessionHash)
return nil
}
func newSchedulerTestOpenAIWSV2Config() *config.Config {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
return cfg
}
type openAIAdvancedSchedulerSettingRepoStub struct {
values map[string]string
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
value, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &Setting{Key: key, Value: value}, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if s == nil || s.values == nil {
return "", ErrSettingNotFound
}
value, ok := s.values[key]
if !ok {
return "", ErrSettingNotFound
}
return value, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected call to Set")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
panic("unexpected call to GetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected call to SetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected call to GetAll")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
panic("unexpected call to Delete")
}
func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
repo := &openAIAdvancedSchedulerSettingRepoStub{
values: map[string]string{},
}
if enabled != "" {
repo.values[openAIAdvancedSchedulerSettingKey] = enabled
}
return &RateLimitService{
settingService: NewSettingService(repo, &config.Config{}),
}
}
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 { if len(s.snapshotAccounts) == 0 {
return nil, false, nil return nil, false, nil
...@@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6 ...@@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil return &cloned, nil
} }
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10106)
accounts := []Account{
{
ID: 36001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
},
{
ID: 36002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cache := &schedulerTestGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_disabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
require.False(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10107)
accounts := []Account{
{
ID: 37001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
{
ID: 37002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_enabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(37001), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
require.True(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(10101) groupID := int64(10101)
...@@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite ...@@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache} snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
cache: cache,
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err) require.NoError(t, err)
...@@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa ...@@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache} snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err) require.NoError(t, err)
...@@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR ...@@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{ snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup}, snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
} }
snapshotService := &SchedulerSnapshotService{cache: snapshotCache} snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache, cache: cache,
cfg: &config.Config{}, cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService, schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
} }
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
...@@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche ...@@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
} }
snapshotService := &SchedulerSnapshotService{cache: snapshotCache} snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{}, cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService, schedulerSnapshot: snapshotService,
} }
...@@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( ...@@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true, "openai_apikey_responses_websockets_v2_enabled": true,
}, },
} }
cache := &stubGatewayCache{} cache := &schedulerTestGatewayCache{}
cfg := &config.Config{} cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true
...@@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( ...@@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
} }
store := svc.getOpenAIWSStateStore() store := svc.getOpenAIWSStateStore()
...@@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin ...@@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true, Schedulable: true,
Concurrency: 1, Concurrency: 1,
} }
cache := &stubGatewayCache{ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{ sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID, "openai:session_hash_abc": account.ID,
}, },
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache, cache: cache,
cfg: &config.Config{}, cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
} }
selection, decision, err := svc.SelectAccountWithScheduler( selection, decision, err := svc.SelectAccountWithScheduler(
...@@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS ...@@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9, Priority: 9,
}, },
} }
cache := &stubGatewayCache{ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{ sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001, "openai:session_hash_sticky_busy": 21001,
}, },
...@@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS ...@@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
concurrencyCache := stubConcurrencyCache{ concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{ acquireResults: map[int64]bool{
21001: false, // sticky 账号已满 21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
...@@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS ...@@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
...@@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP ...@@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true, "openai_ws_force_http": true,
}, },
} }
cache := &stubGatewayCache{ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{ sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID, "openai:session_hash_force_http": account.ID,
}, },
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache, cache: cache,
cfg: &config.Config{}, cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
} }
selection, decision, err := svc.SelectAccountWithScheduler( selection, decision, err := svc.SelectAccountWithScheduler(
...@@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick ...@@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}, },
}, },
} }
cache := &stubGatewayCache{ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{ sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201, "openai:session_hash_ws_only": 2201,
}, },
} }
cfg := newOpenAIWSV2TestConfig() cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
concurrencyCache := stubConcurrencyCache{ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{ loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
...@@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick ...@@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
...@@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl ...@@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{}, cache: &schedulerTestGatewayCache{},
cfg: newOpenAIWSV2TestConfig(), cfg: newSchedulerTestOpenAIWSV2Config(),
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
} }
selection, decision, err := svc.SelectAccountWithScheduler( selection, decision, err := svc.SelectAccountWithScheduler(
...@@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback ...@@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
concurrencyCache := stubConcurrencyCache{ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{ loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
...@@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback ...@@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{}, cache: &schedulerTestGatewayCache{},
cfg: cfg, cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
...@@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { ...@@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true, Schedulable: true,
Concurrency: 1, Concurrency: 1,
} }
cache := &stubGatewayCache{ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{ sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID, "openai:session_hash_metrics": account.ID,
}, },
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache, cache: cache,
cfg: &config.Config{}, cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
} }
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
...@@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA ...@@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
concurrencyCache := stubConcurrencyCache{ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{ loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
...@@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA ...@@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
}, },
} }
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg, cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
...@@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { ...@@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
} }
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{} svc := &OpenAIGatewayService{}
ttft := 120 ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch() svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK()) require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
...@@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * ...@@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
cfg := newOpenAIWSV2TestConfig() cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg} scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{ account := &Account{
ID: 8801, ID: 8801,
......
...@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh ...@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
cache: &stubGatewayCache{}, cache: &schedulerTestGatewayCache{},
cfg: cfg, cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
} }
selection, decision, err := svc.SelectAccountWithScheduler( selection, decision, err := svc.SelectAccountWithScheduler(
......
...@@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo ...@@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText, SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax, SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
} }
vals, err := s.settingRepo.GetMultiple(ctx, keys) vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil { if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err) return nil, fmt.Errorf("get payment config settings: %w", err)
} }
cfg := s.parsePaymentConfig(vals) cfg := s.parsePaymentConfig(vals)
if s.entClient != nil {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("list enabled provider instances: %w", err)
}
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances))
} else {
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil)
}
// Load Stripe publishable key from the first enabled Stripe provider instance // Load Stripe publishable key from the first enabled Stripe provider instance
cfg.StripePublishableKey = s.getStripePublishableKey(ctx) cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
return cfg, nil return cfg, nil
...@@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme ...@@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
} }
if raw := vals[SettingEnabledPaymentTypes]; raw != "" { if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") { for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t) t = strings.TrimSpace(t)
if t != "" { if t != "" {
cfg.EnabledTypes = append(cfg.EnabledTypes, t) types = append(types, t)
} }
} }
cfg.EnabledTypes = NormalizeVisibleMethods(types)
} }
return cfg return cfg
} }
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance. // getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
if s.entClient == nil {
return ""
}
instances, err := s.entClient.PaymentProviderInstance.Query(). instances, err := s.entClient.PaymentProviderInstance.Query().
Where( Where(
paymentproviderinstance.EnabledEQ(true), paymentproviderinstance.EnabledEQ(true),
...@@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int { ...@@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int {
} }
return v return v
} }
func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
available := make(map[string]bool, 4)
for _, inst := range instances {
switch inst.ProviderKey {
case payment.TypeAlipay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
available[VisibleMethodSourceOfficialAlipay] = true
}
case payment.TypeWxpay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
available[VisibleMethodSourceOfficialWechat] = true
}
case payment.TypeEasyPay:
for _, supportedType := range splitTypes(inst.SupportedTypes) {
switch NormalizeVisibleMethod(supportedType) {
case payment.TypeAlipay:
available[VisibleMethodSourceEasyPayAlipay] = true
case payment.TypeWxpay:
available[VisibleMethodSourceEasyPayWechat] = true
}
}
}
}
return available
}
func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
shouldExpose := map[string]bool{
payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
}
seen := make(map[string]struct{}, len(base)+2)
out := make([]string, 0, len(base)+2)
appendType := func(paymentType string) {
paymentType = NormalizeVisibleMethod(paymentType)
if paymentType == "" {
return
}
if _, ok := seen[paymentType]; ok {
return
}
seen[paymentType] = struct{}{}
out = append(out, paymentType)
}
for _, paymentType := range base {
visibleMethod := NormalizeVisibleMethod(paymentType)
switch visibleMethod {
case payment.TypeAlipay, payment.TypeWxpay:
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
default:
appendType(visibleMethod)
}
}
for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
}
return out
}
func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
enabledKey := visibleMethodEnabledSettingKey(method)
sourceKey := visibleMethodSourceSettingKey(method)
if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
return false
}
source := NormalizeVisibleMethodSource(method, vals[sourceKey])
return source != "" && available[source]
}
package service package service
import ( import (
"context"
"database/sql"
"testing" "testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
) )
func TestPcParseFloat(t *testing.T) { func TestPcParseFloat(t *testing.T) {
...@@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) { ...@@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) {
} }
}) })
t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 2 {
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
}
})
t.Run("empty enabled types string", func(t *testing.T) { t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel() t.Parallel()
vals := map[string]string{ vals := map[string]string{
...@@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) { ...@@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) {
}) })
} }
} }
func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
t.Parallel()
base := []string{"alipay", "wxpay", "stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
}
available := map[string]bool{
VisibleMethodSourceOfficialAlipay: true,
VisibleMethodSourceOfficialWechat: false,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"alipay", "stripe"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
t.Parallel()
base := []string{"stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
}
available := map[string]bool{
VisibleMethodSourceEasyPayAlipay: true,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"stripe", "alipay"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
t.Parallel()
instances := []*dbent.PaymentProviderInstance{
{ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
{ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
{ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
}
got := buildVisibleMethodSourceAvailability(instances)
if !got[VisibleMethodSourceOfficialAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
}
if !got[VisibleMethodSourceEasyPayAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
}
if !got[VisibleMethodSourceOfficialWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
}
if !got[VisibleMethodSourceEasyPayWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
}
}
func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay instance: %v", err)
}
svc := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: "easypay",
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: "wxpay",
},
},
}
cfg, err := svc.GetPaymentConfig(ctx)
if err != nil {
t.Fatalf("GetPaymentConfig returned error: %v", err)
}
want := []string{payment.TypeAlipay, payment.TypeStripe}
if len(cfg.EnabledTypes) != len(want) {
t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
}
for i := range want {
if cfg.EnabledTypes[i] != want[i] {
t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
}
}
}
func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
t.Helper()
db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared")
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
t.Fatalf("enable foreign keys: %v", err)
}
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
type paymentConfigSettingRepoStub struct {
values map[string]string
}
func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
return nil, nil
}
func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
return s.values[key], nil
}
func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = s.values[key]
}
return out, nil
}
func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
PaymentSourceHostedRedirect = "hosted_redirect"
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
paymentResumeFallbackSigningKey = "sub2api-payment-resume"
SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
VisibleMethodSourceOfficialAlipay = "official_alipay"
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
VisibleMethodSourceOfficialWechat = "official_wxpay"
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
)
type ResumeTokenClaims struct {
OrderID int64 `json:"oid"`
UserID int64 `json:"uid,omitempty"`
ProviderInstanceID string `json:"pi,omitempty"`
ProviderKey string `json:"pk,omitempty"`
PaymentType string `json:"pt,omitempty"`
CanonicalReturnURL string `json:"ru,omitempty"`
IssuedAt int64 `json:"iat"`
}
type PaymentResumeService struct {
signingKey []byte
}
type visibleMethodLoadBalancer struct {
inner payment.LoadBalancer
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
func NormalizeVisibleMethods(methods []string) []string {
if len(methods) == 0 {
return nil
}
seen := make(map[string]struct{}, len(methods))
out := make([]string, 0, len(methods))
for _, method := range methods {
normalized := NormalizeVisibleMethod(method)
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
out = append(out, normalized)
}
return out
}
func NormalizePaymentSource(source string) string {
switch strings.TrimSpace(strings.ToLower(source)) {
case "", PaymentSourceHostedRedirect:
return PaymentSourceHostedRedirect
case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
return PaymentSourceWechatInAppResume
default:
return strings.TrimSpace(strings.ToLower(source))
}
}
func NormalizeVisibleMethodSource(method, source string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
return VisibleMethodSourceOfficialAlipay
case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayAlipay
}
case payment.TypeWxpay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
return VisibleMethodSourceOfficialWechat
case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayWechat
}
}
return ""
}
func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
switch NormalizeVisibleMethodSource(method, source) {
case VisibleMethodSourceOfficialAlipay:
return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceEasyPayAlipay:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceOfficialWechat:
return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
case VisibleMethodSourceEasyPayWechat:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
default:
return "", false
}
}
func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
if inner == nil || configService == nil || configService.settingRepo == nil {
return inner
}
return &visibleMethodLoadBalancer{inner: inner, configService: configService}
}
func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
return lb.inner.GetInstanceConfig(ctx, instanceID)
}
func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
visibleMethod := NormalizeVisibleMethod(paymentType)
if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
}
enabledKey := visibleMethodEnabledSettingKey(visibleMethod)
sourceKey := visibleMethodSourceSettingKey(visibleMethod)
vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey})
if err != nil {
return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err)
}
if vals[enabledKey] != "true" {
return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod)
}
targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey])
if !ok {
return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod)
}
return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount)
}
func visibleMethodEnabledSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipayEnabled
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpayEnabled
default:
return ""
}
}
func visibleMethodSourceSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipaySource
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpaySource
default:
return ""
}
}
func CanonicalizeReturnURL(raw string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
}
parsed, err := url.Parse(raw)
if err != nil || !parsed.IsAbs() || parsed.Host == "" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
}
parsed.Fragment = ""
if parsed.Path == "" {
parsed.Path = "/"
}
return parsed.String(), nil
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal resume claims: %w", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return encodedPayload + "." + s.sign(encodedPayload), nil
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
parts := strings.Split(token, ".")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
}
var claims ResumeTokenClaims
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
}
if claims.OrderID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
}
return &claims, nil
}
func (s *PaymentResumeService) sign(payload string) string {
key := s.signingKey
if len(key) == 0 {
key = []byte(paymentResumeFallbackSigningKey)
}
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeVisibleMethods(t *testing.T) {
t.Parallel()
got := NormalizeVisibleMethods([]string{
"alipay_direct",
"alipay",
" wxpay_direct ",
"wxpay",
"stripe",
})
want := []string{"alipay", "wxpay", "stripe"}
if len(got) != len(want) {
t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestNormalizePaymentSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expect string
}{
{name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
{name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
{name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizePaymentSource(tt.input); got != tt.expect {
t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
}
})
}
}
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
if got != "https://example.com/pay/result?b=2" {
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
}
}
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
func TestPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
ProviderInstanceID: "19",
ProviderKey: "easypay",
PaymentType: "wxpay",
CanonicalReturnURL: "https://example.com/payment/result",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
claims, err := svc.ParseToken(token)
if err != nil {
t.Fatalf("ParseToken returned error: %v", err)
}
if claims.OrderID != 42 || claims.UserID != 7 {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
t.Fatalf("claims provider snapshot mismatch: %+v", claims)
}
if claims.CanonicalReturnURL != "https://example.com/payment/result" {
t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
input string
want string
}{
{name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
{name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
{name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
{name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
{name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
}
})
}
}
func TestVisibleMethodProviderKeyForSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
source string
want string
ok bool
}{
{name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
{name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
{name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
{name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
{name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
if got != tt.want || ok != tt.ok {
t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
}
})
}
}
func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != payment.TypeAlipay {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodWxpayEnabled: "false",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
t.Fatal("SelectInstance should reject disabled visible method")
}
}
type paymentSettingRepoStub struct {
values map[string]string
}
func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil }
func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
return s.values[key], nil
}
func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil }
func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = s.values[key]
}
return out, nil
}
func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil }
func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil }
type captureLoadBalancer struct {
lastProviderKey string
lastPaymentType string
}
func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
return map[string]string{}, nil
}
func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
c.lastProviderKey = providerKey
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}
...@@ -65,15 +65,17 @@ func generateRandomString(n int) string { ...@@ -65,15 +65,17 @@ func generateRandomString(n int) string {
} }
type CreateOrderRequest struct { type CreateOrderRequest struct {
UserID int64 UserID int64
Amount float64 Amount float64
PaymentType string PaymentType string
ClientIP string ClientIP string
IsMobile bool IsMobile bool
SrcHost string SrcHost string
SrcURL string SrcURL string
OrderType string ReturnURL string
PlanID int64 PaymentSource string
OrderType string
PlanID int64
} }
type CreateOrderResponse struct { type CreateOrderResponse struct {
...@@ -88,6 +90,7 @@ type CreateOrderResponse struct { ...@@ -88,6 +90,7 @@ type CreateOrderResponse struct {
ClientSecret string `json:"client_secret,omitempty"` ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"` PaymentMode string `json:"payment_mode,omitempty"`
ResumeToken string `json:"resume_token,omitempty"`
} }
type OrderListParams struct { type OrderListParams struct {
...@@ -165,10 +168,13 @@ type PaymentService struct { ...@@ -165,10 +168,13 @@ type PaymentService struct {
configService *PaymentConfigService configService *PaymentConfigService
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
resumeService *PaymentResumeService
} }
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
return svc
} }
// --- Provider Registry --- // --- Provider Registry ---
...@@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string { ...@@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string {
return &s return &s
} }
func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func psSliceContains(sl []string, s string) bool { func psSliceContains(sl []string, s string) bool {
for _, v := range sl { for _, v := range sl {
if v == s { if v == s {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/url" "net/url"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
...@@ -114,6 +115,66 @@ type SettingService struct { ...@@ -114,6 +115,66 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder webSearchManagerBuilder WebSearchManagerBuilder
} }
type ProviderDefaultGrantSettings struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
GrantOnSignup bool
GrantOnFirstBind bool
}
type AuthSourceDefaultSettings struct {
Email ProviderDefaultGrantSettings
LinuxDo ProviderDefaultGrantSettings
OIDC ProviderDefaultGrantSettings
WeChat ProviderDefaultGrantSettings
ForceEmailOnThirdPartySignup bool
}
type authSourceDefaultKeySet struct {
balance string
concurrency string
subscriptions string
grantOnSignup string
grantOnFirstBind string
}
var (
emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultEmailBalance,
concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
}
linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
}
oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultOIDCBalance,
concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
}
weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultWeChatBalance,
concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
}
)
const (
defaultAuthSourceBalance = 0
defaultAuthSourceConcurrency = 5
)
// NewSettingService 创建系统设置服务实例 // NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{ return &SettingService{
...@@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" { if oidcProviderName == "" {
oidcProviderName = "OIDC" oidcProviderName = "OIDC"
} }
weChatEnabled := isWeChatOAuthConfigured()
// Password reset requires email verification to be enabled // Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
...@@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints], CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
WeChatOAuthEnabled: weChatEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true", PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled, OIDCOAuthEnabled: oidcEnabled,
...@@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"` CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"` PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
...@@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled, PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
...@@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { ...@@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result return result
} }
func isWeChatOAuthConfigured() bool {
openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" &&
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != ""
mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" &&
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != ""
return openConfigured || mpConfigured
}
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". // safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage { func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw) raw = strings.TrimSpace(raw)
...@@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS ...@@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return parseDefaultSubscriptions(value) return parseDefaultSubscriptions(value)
} }
func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
keys := []string{
SettingKeyAuthSourceDefaultEmailBalance,
SettingKeyAuthSourceDefaultEmailConcurrency,
SettingKeyAuthSourceDefaultEmailSubscriptions,
SettingKeyAuthSourceDefaultEmailGrantOnSignup,
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
SettingKeyAuthSourceDefaultLinuxDoBalance,
SettingKeyAuthSourceDefaultLinuxDoConcurrency,
SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
SettingKeyAuthSourceDefaultOIDCBalance,
SettingKeyAuthSourceDefaultOIDCConcurrency,
SettingKeyAuthSourceDefaultOIDCSubscriptions,
SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
SettingKeyAuthSourceDefaultWeChatBalance,
SettingKeyAuthSourceDefaultWeChatConcurrency,
SettingKeyAuthSourceDefaultWeChatSubscriptions,
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
SettingKeyForceEmailOnThirdPartySignup,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get auth source default settings: %w", err)
}
return &AuthSourceDefaultSettings{
Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
}, nil
}
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
if settings == nil {
return nil
}
for _, subscriptions := range [][]DefaultSubscriptionSetting{
settings.Email.Subscriptions,
settings.LinuxDo.Subscriptions,
settings.OIDC.Subscriptions,
settings.WeChat.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return err
}
}
updates := make(map[string]string, 21)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
return fmt.Errorf("update auth source default settings: %w", err)
}
return nil
}
// InitializeDefaultSettings 初始化默认设置 // InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置 // 检查是否已有设置
...@@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置 // 初始化默认设置
defaults := map[string]string{ defaults := map[string]string{
SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false", SettingKeyEmailVerifyEnabled: "false",
SettingKeyRegistrationEmailSuffixWhitelist: "[]", SettingKeyRegistrationEmailSuffixWhitelist: "[]",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API", SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "", SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "", SettingKeyPurchaseSubscriptionURL: "",
SettingKeyTableDefaultPageSize: "20", SettingKeyTableDefaultPageSize: "20",
SettingKeyTablePageSizeOptions: "[10,20,50,100]", SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]", SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]", SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false", SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC", SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]", SettingKeyDefaultSubscriptions: "[]",
SettingKeySMTPPort: "587", SettingKeyAuthSourceDefaultEmailBalance: "0",
SettingKeySMTPUseTLS: "false", SettingKeyAuthSourceDefaultEmailConcurrency: "5",
SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultOIDCBalance: "0",
SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true",
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultWeChatBalance: "0",
SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true",
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
SettingKeyForceEmailOnThirdPartySignup: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults // Model fallback defaults
SettingKeyEnableModelFallback: "false", SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
...@@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else { } else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
} }
result.OIDCConnectUsePKCE = true
result.OIDCConnectValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else { } else {
...@@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { ...@@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized return normalized
} }
func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
result := ProviderDefaultGrantSettings{
Balance: defaultAuthSourceBalance,
Concurrency: defaultAuthSourceConcurrency,
Subscriptions: []DefaultSubscriptionSetting{},
GrantOnSignup: true,
GrantOnFirstBind: false,
}
if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
result.Balance = v
}
if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
result.Concurrency = v
}
if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
result.Subscriptions = items
}
if raw, ok := settings[keys.grantOnSignup]; ok {
result.GrantOnSignup = raw == "true"
}
if raw, ok := settings[keys.grantOnFirstBind]; ok {
result.GrantOnFirstBind = raw == "true"
}
return result
}
func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
subscriptions := settings.Subscriptions
if subscriptions == nil {
subscriptions = []DefaultSubscriptionSetting{}
}
raw, err := json.Marshal(subscriptions)
if err != nil {
raw = []byte("[]")
}
updates[keys.subscriptions] = string(raw)
updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
}
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20 defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
...@@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf ...@@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v) effective.RedirectURL = strings.TrimSpace(v)
} }
effective.UsePKCE = true
if !effective.Enabled { if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
...@@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf ...@@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
} }
case "none": case "none":
if !effective.UsePKCE {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default: default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
} }
...@@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. ...@@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true" effective.ValidateIDToken = raw == "true"
} }
effective.UsePKCE = true
effective.ValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v) effective.AllowedSigningAlgs = strings.TrimSpace(v)
} }
...@@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. ...@@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
} }
case "none": case "none":
if !effective.UsePKCE {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default: default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
} }
......
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type authSourceDefaultsRepoStub struct {
values map[string]string
updates map[string]string
}
func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.updates = make(map[string]string, len(settings))
for key, value := range settings {
s.updates[key] = value
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
repo := &authSourceDefaultsRepoStub{
values: map[string]string{
SettingKeyAuthSourceDefaultEmailBalance: "12.5",
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := NewSettingService(repo, &config.Config{})
got, err := svc.GetAuthSourceDefaultSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 12.5, got.Email.Balance)
require.Equal(t, 7, got.Email.Concurrency)
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
require.False(t, got.Email.GrantOnSignup)
require.False(t, got.Email.GrantOnFirstBind)
require.Equal(t, 0.0, got.LinuxDo.Balance)
require.Equal(t, 5, got.LinuxDo.Concurrency)
require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
require.True(t, got.LinuxDo.GrantOnSignup)
require.True(t, got.LinuxDo.GrantOnFirstBind)
require.Equal(t, 5, got.OIDC.Concurrency)
require.Equal(t, 5, got.WeChat.Concurrency)
require.True(t, got.ForceEmailOnThirdPartySignup)
}
func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
repo := &authSourceDefaultsRepoStub{}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
Email: ProviderDefaultGrantSettings{
Balance: 1.25,
Concurrency: 3,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
GrantOnSignup: false,
GrantOnFirstBind: true,
},
LinuxDo: ProviderDefaultGrantSettings{
Balance: 2,
Concurrency: 4,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
GrantOnSignup: true,
GrantOnFirstBind: false,
},
OIDC: ProviderDefaultGrantSettings{
Balance: 3,
Concurrency: 5,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
GrantOnSignup: true,
GrantOnFirstBind: true,
},
WeChat: ProviderDefaultGrantSettings{
Balance: 4,
Concurrency: 6,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
GrantOnSignup: false,
GrantOnFirstBind: false,
},
ForceEmailOnThirdPartySignup: true,
})
require.NoError(t, err)
require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
var got []DefaultSubscriptionSetting
require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
}
...@@ -152,6 +152,7 @@ type PublicSettings struct { ...@@ -152,6 +152,7 @@ type PublicSettings struct {
CustomEndpoints string // JSON array of custom endpoints CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
WeChatOAuthEnabled bool
BackendModeEnabled bool BackendModeEnabled bool
PaymentEnabled bool PaymentEnabled bool
OIDCOAuthEnabled bool OIDCOAuthEnabled bool
......
...@@ -7,19 +7,27 @@ import ( ...@@ -7,19 +7,27 @@ import (
) )
type User struct { type User struct {
ID int64 ID int64
Email string Email string
Username string Username string
Notes string Notes string
PasswordHash string AvatarURL string
Role string AvatarSource string
Balance float64 AvatarMIME string
Concurrency int AvatarByteSize int
Status string AvatarSHA256 string
AllowedGroups []int64 PasswordHash string
TokenVersion int64 // Incremented on password change to invalidate existing tokens Role string
CreatedAt time.Time Balance float64
UpdatedAt time.Time Concurrency int
Status string
AllowedGroups []int64
TokenVersion int64 // Incremented on password change to invalidate existing tokens
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置 // GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier // map[groupID]rateMultiplier
......
...@@ -2,9 +2,13 @@ package service ...@@ -2,9 +2,13 @@ package service
import ( import (
"context" "context"
"crypto/sha256"
"crypto/subtle" "crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt" "fmt"
"log/slog" "log/slog"
"net/url"
"strings" "strings"
"time" "time"
...@@ -17,10 +21,14 @@ var ( ...@@ -17,10 +21,14 @@ var (
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
) )
const ( const (
maxNotifyEmails = 3 // Maximum number of notification emails per user maxNotifyEmails = 3 // Maximum number of notification emails per user
maxInlineAvatarBytes = 100 * 1024
// User-level rate limiting for notify email verification codes // User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5 notifyCodeUserRateLimit = 5
...@@ -47,6 +55,9 @@ type UserRepository interface { ...@@ -47,6 +55,9 @@ type UserRepository interface {
GetFirstAdmin(ctx context.Context) (*User, error) GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
DeleteUserAvatar(ctx context.Context, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
...@@ -71,11 +82,30 @@ type UserRepository interface { ...@@ -71,11 +82,30 @@ type UserRepository interface {
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Email *string `json:"email"` Email *string `json:"email"`
Username *string `json:"username"` Username *string `json:"username"`
AvatarURL *string `json:"avatar_url"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
} }
type UserAvatar struct {
StorageProvider string
StorageKey string
URL string
ContentType string
ByteSize int
SHA256 string
}
type UpsertUserAvatarInput struct {
StorageProvider string
StorageKey string
URL string
ContentType string
ByteSize int
SHA256 string
}
// ChangePasswordRequest 修改密码请求 // ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct { type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"` CurrentPassword string `json:"current_password"`
...@@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro ...@@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil { if err != nil {
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil return user, nil
} }
...@@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat ...@@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username user.Username = *req.Username
} }
if req.AvatarURL != nil {
avatarValue := strings.TrimSpace(*req.AvatarURL)
switch {
case avatarValue == "":
if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
return nil, fmt.Errorf("delete avatar: %w", err)
}
applyUserAvatar(user, nil)
default:
avatarInput, err := normalizeUserAvatarInput(avatarValue)
if err != nil {
return nil, err
}
avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
if err != nil {
return nil, fmt.Errorf("upsert avatar: %w", err)
}
applyUserAvatar(user, avatar)
}
}
if req.Concurrency != nil { if req.Concurrency != nil {
user.Concurrency = *req.Concurrency user.Concurrency = *req.Concurrency
} }
...@@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat ...@@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return user, nil return user, nil
} }
func applyUserAvatar(user *User, avatar *UserAvatar) {
if user == nil {
return
}
if avatar == nil {
user.AvatarURL = ""
user.AvatarSource = ""
user.AvatarMIME = ""
user.AvatarByteSize = 0
user.AvatarSHA256 = ""
return
}
user.AvatarURL = avatar.URL
user.AvatarSource = avatar.StorageProvider
user.AvatarMIME = avatar.ContentType
user.AvatarByteSize = avatar.ByteSize
user.AvatarSHA256 = avatar.SHA256
}
func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if strings.HasPrefix(raw, "data:") {
return normalizeInlineUserAvatarInput(raw)
}
parsed, err := url.Parse(raw)
if err != nil || parsed == nil {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if strings.TrimSpace(parsed.Host) == "" {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
return UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: raw,
}, nil
}
func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
body := strings.TrimPrefix(raw, "data:")
meta, encoded, ok := strings.Cut(body, ",")
if !ok {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
meta = strings.TrimSpace(meta)
encoded = strings.TrimSpace(encoded)
if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
return UpsertUserAvatarInput{}, ErrAvatarNotImage
}
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if len(decoded) > maxInlineAvatarBytes {
return UpsertUserAvatarInput{}, ErrAvatarTooLarge
}
sum := sha256.Sum256(decoded)
return UpsertUserAvatarInput{
StorageProvider: "inline",
URL: raw,
ContentType: contentType,
ByteSize: len(decoded),
SHA256: hex.EncodeToString(sum[:]),
}, nil
}
// ChangePassword 修改密码 // ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens // Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
...@@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { ...@@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil return user, nil
} }
func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
return nil
}
avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
if err != nil {
return err
}
applyUserAvatar(user, avatar)
return nil
}
// List 获取用户列表(管理员功能) // List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params) users, pagination, err := s.userRepo.List(ctx, params)
......
...@@ -4,6 +4,9 @@ package service ...@@ -4,6 +4,9 @@ package service
import ( import (
"context" "context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors" "errors"
"sync" "sync"
"sync/atomic" "sync/atomic"
...@@ -19,14 +22,65 @@ import ( ...@@ -19,14 +22,65 @@ import (
type mockUserRepo struct { type mockUserRepo struct {
updateBalanceErr error updateBalanceErr error
updateBalanceFn func(ctx context.Context, id int64, amount float64) error updateBalanceFn func(ctx context.Context, id int64, amount float64) error
getByIDUser *User
getByIDErr error
updateFn func(ctx context.Context, user *User) error
updateCalls int
upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
upsertAvatarArgs []UpsertUserAvatarInput
deleteAvatarFn func(ctx context.Context, userID int64) error
deleteAvatarIDs []int64
getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
} }
func (m *mockUserRepo) Create(context.Context, *User) error { return nil } func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
if m.getByIDErr != nil {
return nil, m.getByIDErr
}
if m.getByIDUser != nil {
cloned := *m.getByIDUser
return &cloned, nil
}
return &User{}, nil
}
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) Update(context.Context, *User) error { return nil } func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } m.updateCalls++
if m.updateFn != nil {
return m.updateFn(ctx, user)
}
return nil
}
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
if m.getAvatarFn != nil {
return m.getAvatarFn(ctx, userID)
}
return nil, nil
}
func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
if m.upsertAvatarFn != nil {
return m.upsertAvatarFn(ctx, userID, input)
}
return &UserAvatar{
StorageProvider: input.StorageProvider,
StorageKey: input.StorageKey,
URL: input.URL,
ContentType: input.ContentType,
ByteSize: input.ByteSize,
SHA256: input.SHA256,
}, nil
}
func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
if m.deleteAvatarFn != nil {
return m.deleteAvatarFn(ctx, userID)
}
return nil
}
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
...@@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { ...@@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache) require.Equal(t, cache, svc.billingCache)
} }
func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
raw := []byte("small-avatar")
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
expectedSum := sha256.Sum256(raw)
repo := &mockUserRepo{
getByIDUser: &User{
ID: 7,
Email: "avatar@example.com",
Username: "avatar-user",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
AvatarURL: &dataURL,
})
require.NoError(t, err)
require.Len(t, repo.upsertAvatarArgs, 1)
require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
require.Equal(t, dataURL, updated.AvatarURL)
require.Equal(t, "inline", updated.AvatarSource)
require.Equal(t, "image/png", updated.AvatarMIME)
require.Equal(t, len(raw), updated.AvatarByteSize)
require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
}
func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
raw := make([]byte, maxInlineAvatarBytes+1)
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
repo := &mockUserRepo{
getByIDUser: &User{
ID: 8,
Email: "large-avatar@example.com",
Username: "too-large",
},
}
svc := NewUserService(repo, nil, nil, nil)
_, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
AvatarURL: &dataURL,
})
require.ErrorIs(t, err, ErrAvatarTooLarge)
require.Empty(t, repo.upsertAvatarArgs)
require.Empty(t, repo.deleteAvatarIDs)
require.Zero(t, repo.updateCalls)
}
func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
remoteURL := "https://cdn.example.com/avatar.png"
repo := &mockUserRepo{
getByIDUser: &User{
ID: 9,
Email: "remote-avatar@example.com",
Username: "remote-avatar",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
AvatarURL: &remoteURL,
})
require.NoError(t, err)
require.Len(t, repo.upsertAvatarArgs, 1)
require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
require.Equal(t, remoteURL, updated.AvatarURL)
require.Equal(t, "remote_url", updated.AvatarSource)
require.Zero(t, updated.AvatarByteSize)
}
func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
empty := ""
repo := &mockUserRepo{
getByIDUser: &User{
ID: 10,
Email: "delete-avatar@example.com",
Username: "delete-avatar",
AvatarURL: "https://cdn.example.com/old.png",
AvatarSource: "remote_url",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
AvatarURL: &empty,
})
require.NoError(t, err)
require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
require.Empty(t, repo.upsertAvatarArgs)
require.Empty(t, updated.AvatarURL)
require.Empty(t, updated.AvatarSource)
}
func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 12,
Email: "profile-avatar@example.com",
Username: "profile-avatar",
},
getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
return &UserAvatar{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/profile.png",
}, nil
},
}
svc := NewUserService(repo, nil, nil, nil)
user, err := svc.GetProfile(context.Background(), 12)
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
require.Equal(t, "remote_url", user.AvatarSource)
}
ALTER TABLE users
ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email',
ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL,
ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL;
UPDATE users
SET signup_source = 'email'
WHERE signup_source IS NULL OR signup_source = '';
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'users_signup_source_check'
) THEN
ALTER TABLE users
ADD CONSTRAINT users_signup_source_check
CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc'));
END IF;
END $$;
CREATE TABLE IF NOT EXISTS auth_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider_type VARCHAR(20) NOT NULL,
provider_key TEXT NOT NULL,
provider_subject TEXT NOT NULL,
verified_at TIMESTAMPTZ NULL,
issuer TEXT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT auth_identities_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
);
CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key
ON auth_identities (provider_type, provider_key, provider_subject);
CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx
ON auth_identities (user_id);
CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx
ON auth_identities (user_id, provider_type);
CREATE TABLE IF NOT EXISTS auth_identity_channels (
id BIGSERIAL PRIMARY KEY,
identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE,
provider_type VARCHAR(20) NOT NULL,
provider_key TEXT NOT NULL,
channel VARCHAR(20) NOT NULL,
channel_app_id TEXT NOT NULL,
channel_subject TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT auth_identity_channels_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
);
CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key
ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject);
CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx
ON auth_identity_channels (identity_id);
CREATE TABLE IF NOT EXISTS pending_auth_sessions (
id BIGSERIAL PRIMARY KEY,
session_token VARCHAR(255) NOT NULL,
intent VARCHAR(40) NOT NULL,
provider_type VARCHAR(20) NOT NULL,
provider_key TEXT NOT NULL,
provider_subject TEXT NOT NULL,
target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
redirect_to TEXT NOT NULL DEFAULT '',
resolved_email TEXT NOT NULL DEFAULT '',
registration_password_hash TEXT NOT NULL DEFAULT '',
upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb,
local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb,
browser_session_key TEXT NOT NULL DEFAULT '',
completion_code_hash TEXT NOT NULL DEFAULT '',
completion_code_expires_at TIMESTAMPTZ NULL,
email_verified_at TIMESTAMPTZ NULL,
password_verified_at TIMESTAMPTZ NULL,
totp_verified_at TIMESTAMPTZ NULL,
expires_at TIMESTAMPTZ NOT NULL,
consumed_at TIMESTAMPTZ NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT pending_auth_sessions_intent_check
CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')),
CONSTRAINT pending_auth_sessions_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
);
CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key
ON pending_auth_sessions (session_token);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx
ON pending_auth_sessions (target_user_id);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx
ON pending_auth_sessions (expires_at);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx
ON pending_auth_sessions (provider_type, provider_key, provider_subject);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx
ON pending_auth_sessions (completion_code_hash);
CREATE TABLE IF NOT EXISTS identity_adoption_decisions (
id BIGSERIAL PRIMARY KEY,
pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE,
identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL,
adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE,
adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE,
decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key
ON identity_adoption_decisions (pending_auth_session_id);
CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx
ON identity_adoption_decisions (identity_id);
CREATE TABLE IF NOT EXISTS auth_identity_migration_reports (
id BIGSERIAL PRIMARY KEY,
report_type VARCHAR(40) NOT NULL,
report_key TEXT NOT NULL,
details JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx
ON auth_identity_migration_reports (report_type);
CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key
ON auth_identity_migration_reports (report_type, report_key);
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
u.id,
'email',
'email',
LOWER(BTRIM(u.email)),
COALESCE(u.updated_at, u.created_at, NOW()),
jsonb_build_object(
'backfill_source', 'users.email',
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND BTRIM(COALESCE(u.email, '')) <> ''
AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid'
AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid'
AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid'
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
u.id,
'linuxdo',
'linuxdo',
SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'),
COALESCE(u.updated_at, u.created_at, NOW()),
jsonb_build_object(
'backfill_source', 'synthetic_email',
'legacy_email', BTRIM(u.email),
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
u.id,
'wechat',
'wechat',
SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'),
COALESCE(u.updated_at, u.created_at, NOW()),
jsonb_build_object(
'backfill_source', 'synthetic_email',
'legacy_email', BTRIM(u.email),
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
UPDATE users
SET signup_source = 'linuxdo'
WHERE deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$';
UPDATE users
SET signup_source = 'wechat'
WHERE deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$';
UPDATE users
SET signup_source = 'oidc'
WHERE deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$';
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'oidc_synthetic_email_requires_manual_recovery',
CAST(u.id AS TEXT),
jsonb_build_object(
'user_id', u.id,
'email', LOWER(BTRIM(u.email)),
'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone',
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$'
ON CONFLICT (report_type, report_key) DO NOTHING;
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'wechat_openid_only_requires_remediation',
CAST(u.id AS TEXT),
jsonb_build_object(
'user_id', u.id,
'email', LOWER(BTRIM(u.email)),
'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists',
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
AND NOT EXISTS (
SELECT 1
FROM auth_identities ai
WHERE ai.user_id = u.id
AND ai.provider_type = 'wechat'
AND ai.provider_key = 'wechat'
)
ON CONFLICT (report_type, report_key) DO NOTHING;
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider_type VARCHAR(20) NOT NULL,
grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind',
granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT user_provider_default_grants_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')),
CONSTRAINT user_provider_default_grants_reason_check
CHECK (grant_reason IN ('signup', 'first_bind'))
);
CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key
ON user_provider_default_grants (user_id, provider_type, grant_reason);
CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx
ON user_provider_default_grants (user_id);
CREATE TABLE IF NOT EXISTS user_avatars (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
storage_provider VARCHAR(20) NOT NULL DEFAULT 'database',
storage_key TEXT NOT NULL DEFAULT '',
url TEXT NOT NULL DEFAULT '',
content_type VARCHAR(100) NOT NULL DEFAULT '',
byte_size INT NOT NULL DEFAULT 0,
sha256 VARCHAR(64) NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key
ON user_avatars (user_id);
INSERT INTO settings (key, value)
VALUES
('auth_source_default_email_balance', '0'),
('auth_source_default_email_concurrency', '5'),
('auth_source_default_email_subscriptions', '[]'),
('auth_source_default_email_grant_on_signup', 'true'),
('auth_source_default_email_grant_on_first_bind', 'false'),
('auth_source_default_linuxdo_balance', '0'),
('auth_source_default_linuxdo_concurrency', '5'),
('auth_source_default_linuxdo_subscriptions', '[]'),
('auth_source_default_linuxdo_grant_on_signup', 'true'),
('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
('auth_source_default_oidc_balance', '0'),
('auth_source_default_oidc_concurrency', '5'),
('auth_source_default_oidc_subscriptions', '[]'),
('auth_source_default_oidc_grant_on_signup', 'true'),
('auth_source_default_oidc_grant_on_first_bind', 'false'),
('auth_source_default_wechat_balance', '0'),
('auth_source_default_wechat_concurrency', '5'),
('auth_source_default_wechat_subscriptions', '[]'),
('auth_source_default_wechat_grant_on_signup', 'true'),
('auth_source_default_wechat_grant_on_first_bind', 'false'),
('force_email_on_third_party_signup', 'false')
ON CONFLICT (key) DO NOTHING;
INSERT INTO settings (key, value)
VALUES
('payment_visible_method_alipay_source', ''),
('payment_visible_method_wxpay_source', ''),
('payment_visible_method_alipay_enabled', 'false'),
('payment_visible_method_wxpay_enabled', 'false'),
('openai_advanced_scheduler_enabled', 'false')
ON CONFLICT (key) DO NOTHING;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment