Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
e9de839d
Commit
e9de839d
authored
Apr 20, 2026
by
IanShaw027
Browse files
feat: rebuild auth identity foundation flow
parent
fbd0a2e3
Changes
123
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/auth_service_pending_oauth_test.go
deleted
100644 → 0
View file @
fbd0a2e3
//go:build unit
package
service
import
(
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func
newAuthServiceForPendingOAuthTest
()
*
AuthService
{
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret-pending-oauth"
,
ExpireHour
:
1
,
},
}
return
NewAuthService
(
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func
TestVerifyPendingOAuthToken_ValidToken
(
t
*
testing
.
T
)
{
svc
:=
newAuthServiceForPendingOAuthTest
()
token
,
err
:=
svc
.
CreatePendingOAuthToken
(
"user@example.com"
,
"alice"
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
email
,
username
,
err
:=
svc
.
VerifyPendingOAuthToken
(
token
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"user@example.com"
,
email
)
require
.
Equal
(
t
,
"alice"
,
username
)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func
TestVerifyPendingOAuthToken_RegularJWTRejected
(
t
*
testing
.
T
)
{
svc
:=
newAuthServiceForPendingOAuthTest
()
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
accessToken
,
err
:=
svc
.
GenerateToken
(
&
User
{
ID
:
1
,
Email
:
"user@example.com"
,
Role
:
RoleUser
,
})
require
.
NoError
(
t
,
err
)
_
,
_
,
err
=
svc
.
VerifyPendingOAuthToken
(
accessToken
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidToken
)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
func
TestVerifyPendingOAuthToken_WrongPurpose
(
t
*
testing
.
T
)
{
svc
:=
newAuthServiceForPendingOAuthTest
()
now
:=
time
.
Now
()
claims
:=
&
pendingOAuthClaims
{
Email
:
"user@example.com"
,
Username
:
"alice"
,
Purpose
:
"some_other_purpose"
,
RegisteredClaims
:
jwt
.
RegisteredClaims
{
ExpiresAt
:
jwt
.
NewNumericDate
(
now
.
Add
(
10
*
time
.
Minute
)),
IssuedAt
:
jwt
.
NewNumericDate
(
now
),
NotBefore
:
jwt
.
NewNumericDate
(
now
),
},
}
tok
:=
jwt
.
NewWithClaims
(
jwt
.
SigningMethodHS256
,
claims
)
tokenStr
,
err
:=
tok
.
SignedString
([]
byte
(
svc
.
cfg
.
JWT
.
Secret
))
require
.
NoError
(
t
,
err
)
_
,
_
,
err
=
svc
.
VerifyPendingOAuthToken
(
tokenStr
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidToken
)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
func
TestVerifyPendingOAuthToken_MissingPurpose
(
t
*
testing
.
T
)
{
svc
:=
newAuthServiceForPendingOAuthTest
()
now
:=
time
.
Now
()
claims
:=
&
pendingOAuthClaims
{
Email
:
"user@example.com"
,
Username
:
"alice"
,
Purpose
:
""
,
// 旧 token 无此字段,反序列化后为零值
RegisteredClaims
:
jwt
.
RegisteredClaims
{
ExpiresAt
:
jwt
.
NewNumericDate
(
now
.
Add
(
10
*
time
.
Minute
)),
IssuedAt
:
jwt
.
NewNumericDate
(
now
),
NotBefore
:
jwt
.
NewNumericDate
(
now
),
},
}
tok
:=
jwt
.
NewWithClaims
(
jwt
.
SigningMethodHS256
,
claims
)
tokenStr
,
err
:=
tok
.
SignedString
([]
byte
(
svc
.
cfg
.
JWT
.
Secret
))
require
.
NoError
(
t
,
err
)
_
,
_
,
err
=
svc
.
VerifyPendingOAuthToken
(
tokenStr
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidToken
)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func
TestVerifyPendingOAuthToken_ExpiredToken
(
t
*
testing
.
T
)
{
svc
:=
newAuthServiceForPendingOAuthTest
()
past
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
claims
:=
&
pendingOAuthClaims
{
Email
:
"user@example.com"
,
Username
:
"alice"
,
Purpose
:
pendingOAuthPurpose
,
RegisteredClaims
:
jwt
.
RegisteredClaims
{
ExpiresAt
:
jwt
.
NewNumericDate
(
past
),
IssuedAt
:
jwt
.
NewNumericDate
(
past
.
Add
(
-
10
*
time
.
Minute
)),
NotBefore
:
jwt
.
NewNumericDate
(
past
.
Add
(
-
10
*
time
.
Minute
)),
},
}
tok
:=
jwt
.
NewWithClaims
(
jwt
.
SigningMethodHS256
,
claims
)
tokenStr
,
err
:=
tok
.
SignedString
([]
byte
(
svc
.
cfg
.
JWT
.
Secret
))
require
.
NoError
(
t
,
err
)
_
,
_
,
err
=
svc
.
VerifyPendingOAuthToken
(
tokenStr
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidToken
)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func
TestVerifyPendingOAuthToken_WrongSecret
(
t
*
testing
.
T
)
{
other
:=
NewAuthService
(
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"other-secret"
},
},
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
token
,
err
:=
other
.
CreatePendingOAuthToken
(
"user@example.com"
,
"alice"
)
require
.
NoError
(
t
,
err
)
svc
:=
newAuthServiceForPendingOAuthTest
()
_
,
_
,
err
=
svc
.
VerifyPendingOAuthToken
(
token
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidToken
)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func
TestVerifyPendingOAuthToken_TooLong
(
t
*
testing
.
T
)
{
svc
:=
newAuthServiceForPendingOAuthTest
()
giant
:=
make
([]
byte
,
maxTokenLength
+
1
)
for
i
:=
range
giant
{
giant
[
i
]
=
'a'
}
_
,
_
,
err
:=
svc
.
VerifyPendingOAuthToken
(
string
(
giant
))
require
.
ErrorIs
(
t
,
err
,
ErrInvalidToken
)
}
backend/internal/service/domain_constants.go
View file @
e9de839d
...
...
@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const
OIDCConnectSyntheticEmailDomain
=
"@oidc-connect.invalid"
// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
const
WeChatConnectSyntheticEmailDomain
=
"@wechat-connect.invalid"
// Setting keys
const
(
// 注册设置
...
...
@@ -153,6 +156,29 @@ const (
SettingKeyDefaultBalance
=
"default_balance"
// 新用户默认余额
SettingKeyDefaultSubscriptions
=
"default_subscriptions"
// 新用户默认订阅列表(JSON)
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance
=
"auth_source_default_email_balance"
SettingKeyAuthSourceDefaultEmailConcurrency
=
"auth_source_default_email_concurrency"
SettingKeyAuthSourceDefaultEmailSubscriptions
=
"auth_source_default_email_subscriptions"
SettingKeyAuthSourceDefaultEmailGrantOnSignup
=
"auth_source_default_email_grant_on_signup"
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind
=
"auth_source_default_email_grant_on_first_bind"
SettingKeyAuthSourceDefaultLinuxDoBalance
=
"auth_source_default_linuxdo_balance"
SettingKeyAuthSourceDefaultLinuxDoConcurrency
=
"auth_source_default_linuxdo_concurrency"
SettingKeyAuthSourceDefaultLinuxDoSubscriptions
=
"auth_source_default_linuxdo_subscriptions"
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup
=
"auth_source_default_linuxdo_grant_on_signup"
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind
=
"auth_source_default_linuxdo_grant_on_first_bind"
SettingKeyAuthSourceDefaultOIDCBalance
=
"auth_source_default_oidc_balance"
SettingKeyAuthSourceDefaultOIDCConcurrency
=
"auth_source_default_oidc_concurrency"
SettingKeyAuthSourceDefaultOIDCSubscriptions
=
"auth_source_default_oidc_subscriptions"
SettingKeyAuthSourceDefaultOIDCGrantOnSignup
=
"auth_source_default_oidc_grant_on_signup"
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind
=
"auth_source_default_oidc_grant_on_first_bind"
SettingKeyAuthSourceDefaultWeChatBalance
=
"auth_source_default_wechat_balance"
SettingKeyAuthSourceDefaultWeChatConcurrency
=
"auth_source_default_wechat_concurrency"
SettingKeyAuthSourceDefaultWeChatSubscriptions
=
"auth_source_default_wechat_subscriptions"
SettingKeyAuthSourceDefaultWeChatGrantOnSignup
=
"auth_source_default_wechat_grant_on_signup"
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind
=
"auth_source_default_wechat_grant_on_first_bind"
SettingKeyForceEmailOnThirdPartySignup
=
"force_email_on_third_party_signup"
// 管理员 API Key
SettingKeyAdminAPIKey
=
"admin_api_key"
// 全局管理员 API Key(用于外部系统集成)
...
...
backend/internal/service/openai_account_scheduler.go
View file @
e9de839d
...
...
@@ -13,14 +13,30 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/singleflight"
)
const
(
openAIAccountScheduleLayerPreviousResponse
=
"previous_response_id"
openAIAccountScheduleLayerSessionSticky
=
"session_hash"
openAIAccountScheduleLayerLoadBalance
=
"load_balance"
openAIAdvancedSchedulerSettingKey
=
"openai_advanced_scheduler_enabled"
)
const
(
openAIAdvancedSchedulerSettingCacheTTL
=
5
*
time
.
Second
openAIAdvancedSchedulerSettingDBTimeout
=
2
*
time
.
Second
)
type
cachedOpenAIAdvancedSchedulerSetting
struct
{
enabled
bool
expiresAt
int64
}
var
openAIAdvancedSchedulerSettingCache
atomic
.
Value
// *cachedOpenAIAdvancedSchedulerSetting
var
openAIAdvancedSchedulerSettingSF
singleflight
.
Group
type
OpenAIAccountScheduleRequest
struct
{
GroupID
*
int64
SessionHash
string
...
...
@@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return
snapshot
}
func
(
s
*
OpenAIGatewayService
)
getOpenAIAccountScheduler
()
OpenAIAccountScheduler
{
func
(
s
*
OpenAIGatewayService
)
openAIAdvancedSchedulerSettingRepo
()
SettingRepository
{
if
s
==
nil
||
s
.
rateLimitService
==
nil
||
s
.
rateLimitService
.
settingService
==
nil
{
return
nil
}
return
s
.
rateLimitService
.
settingService
.
settingRepo
}
func
(
s
*
OpenAIGatewayService
)
isOpenAIAdvancedSchedulerEnabled
(
ctx
context
.
Context
)
bool
{
if
cached
,
ok
:=
openAIAdvancedSchedulerSettingCache
.
Load
()
.
(
*
cachedOpenAIAdvancedSchedulerSetting
);
ok
&&
cached
!=
nil
{
if
time
.
Now
()
.
UnixNano
()
<
cached
.
expiresAt
{
return
cached
.
enabled
}
}
result
,
_
,
_
:=
openAIAdvancedSchedulerSettingSF
.
Do
(
openAIAdvancedSchedulerSettingKey
,
func
()
(
any
,
error
)
{
if
cached
,
ok
:=
openAIAdvancedSchedulerSettingCache
.
Load
()
.
(
*
cachedOpenAIAdvancedSchedulerSetting
);
ok
&&
cached
!=
nil
{
if
time
.
Now
()
.
UnixNano
()
<
cached
.
expiresAt
{
return
cached
.
enabled
,
nil
}
}
enabled
:=
false
if
repo
:=
s
.
openAIAdvancedSchedulerSettingRepo
();
repo
!=
nil
{
dbCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
WithoutCancel
(
ctx
),
openAIAdvancedSchedulerSettingDBTimeout
)
defer
cancel
()
value
,
err
:=
repo
.
GetValue
(
dbCtx
,
openAIAdvancedSchedulerSettingKey
)
if
err
==
nil
{
enabled
=
strings
.
EqualFold
(
strings
.
TrimSpace
(
value
),
"true"
)
}
}
openAIAdvancedSchedulerSettingCache
.
Store
(
&
cachedOpenAIAdvancedSchedulerSetting
{
enabled
:
enabled
,
expiresAt
:
time
.
Now
()
.
Add
(
openAIAdvancedSchedulerSettingCacheTTL
)
.
UnixNano
(),
})
return
enabled
,
nil
})
enabled
,
_
:=
result
.
(
bool
)
return
enabled
}
func
(
s
*
OpenAIGatewayService
)
getOpenAIAccountScheduler
(
ctx
context
.
Context
)
OpenAIAccountScheduler
{
if
s
==
nil
{
return
nil
}
if
!
s
.
isOpenAIAdvancedSchedulerEnabled
(
ctx
)
{
return
nil
}
s
.
openaiSchedulerOnce
.
Do
(
func
()
{
if
s
.
openaiAccountStats
==
nil
{
s
.
openaiAccountStats
=
newOpenAIAccountRuntimeStats
()
...
...
@@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return
s
.
openaiScheduler
}
func
resetOpenAIAdvancedSchedulerSettingCacheForTest
()
{
openAIAdvancedSchedulerSettingCache
=
atomic
.
Value
{}
openAIAdvancedSchedulerSettingSF
=
singleflight
.
Group
{}
}
func
(
s
*
OpenAIGatewayService
)
SelectAccountWithScheduler
(
ctx
context
.
Context
,
groupID
*
int64
,
...
...
@@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport
OpenAIUpstreamTransport
,
)
(
*
AccountSelectionResult
,
OpenAIAccountScheduleDecision
,
error
)
{
decision
:=
OpenAIAccountScheduleDecision
{}
scheduler
:=
s
.
getOpenAIAccountScheduler
()
scheduler
:=
s
.
getOpenAIAccountScheduler
(
ctx
)
if
scheduler
==
nil
{
selection
,
err
:=
s
.
SelectAccountWithLoadAwareness
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
decision
.
Layer
=
openAIAccountScheduleLayerLoadBalance
...
...
@@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}
func
(
s
*
OpenAIGatewayService
)
ReportOpenAIAccountScheduleResult
(
accountID
int64
,
success
bool
,
firstTokenMs
*
int
)
{
scheduler
:=
s
.
getOpenAIAccountScheduler
()
scheduler
:=
s
.
getOpenAIAccountScheduler
(
context
.
Background
()
)
if
scheduler
==
nil
{
return
}
...
...
@@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func
(
s
*
OpenAIGatewayService
)
RecordOpenAIAccountSwitch
()
{
scheduler
:=
s
.
getOpenAIAccountScheduler
()
scheduler
:=
s
.
getOpenAIAccountScheduler
(
context
.
Background
()
)
if
scheduler
==
nil
{
return
}
...
...
@@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func
(
s
*
OpenAIGatewayService
)
SnapshotOpenAIAccountSchedulerMetrics
()
OpenAIAccountSchedulerMetricsSnapshot
{
scheduler
:=
s
.
getOpenAIAccountScheduler
()
scheduler
:=
s
.
getOpenAIAccountScheduler
(
context
.
Background
()
)
if
scheduler
==
nil
{
return
OpenAIAccountSchedulerMetricsSnapshot
{}
}
...
...
backend/internal/service/openai_account_scheduler_test.go
View file @
e9de839d
...
...
@@ -2,6 +2,7 @@ package service
import
(
"context"
"errors"
"fmt"
"math"
"sync"
...
...
@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID
map
[
int64
]
*
Account
}
type
schedulerTestOpenAIAccountRepo
struct
{
AccountRepository
accounts
[]
Account
}
func
(
r
schedulerTestOpenAIAccountRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
for
i
:=
range
r
.
accounts
{
if
r
.
accounts
[
i
]
.
ID
==
id
{
return
&
r
.
accounts
[
i
],
nil
}
}
return
nil
,
errors
.
New
(
"account not found"
)
}
func
(
r
schedulerTestOpenAIAccountRepo
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
{
var
result
[]
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
Platform
==
platform
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
r
schedulerTestOpenAIAccountRepo
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
var
result
[]
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
Platform
==
platform
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
r
schedulerTestOpenAIAccountRepo
)
ListSchedulableUngroupedByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
r
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
type
schedulerTestConcurrencyCache
struct
{
ConcurrencyCache
loadBatchErr
error
loadMap
map
[
int64
]
*
AccountLoadInfo
acquireResults
map
[
int64
]
bool
waitCounts
map
[
int64
]
int
skipDefaultLoad
bool
}
func
(
c
schedulerTestConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
if
c
.
acquireResults
!=
nil
{
if
result
,
ok
:=
c
.
acquireResults
[
accountID
];
ok
{
return
result
,
nil
}
}
return
true
,
nil
}
func
(
c
schedulerTestConcurrencyCache
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
{
return
nil
}
func
(
c
schedulerTestConcurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
c
.
loadBatchErr
!=
nil
{
return
nil
,
c
.
loadBatchErr
}
out
:=
make
(
map
[
int64
]
*
AccountLoadInfo
,
len
(
accounts
))
if
c
.
skipDefaultLoad
&&
c
.
loadMap
!=
nil
{
for
_
,
acc
:=
range
accounts
{
if
load
,
ok
:=
c
.
loadMap
[
acc
.
ID
];
ok
{
out
[
acc
.
ID
]
=
load
}
}
return
out
,
nil
}
for
_
,
acc
:=
range
accounts
{
if
c
.
loadMap
!=
nil
{
if
load
,
ok
:=
c
.
loadMap
[
acc
.
ID
];
ok
{
out
[
acc
.
ID
]
=
load
continue
}
}
out
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
LoadRate
:
0
}
}
return
out
,
nil
}
func
(
c
schedulerTestConcurrencyCache
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
c
.
waitCounts
!=
nil
{
if
count
,
ok
:=
c
.
waitCounts
[
accountID
];
ok
{
return
count
,
nil
}
}
return
0
,
nil
}
type
schedulerTestGatewayCache
struct
{
sessionBindings
map
[
string
]
int64
deletedSessions
map
[
string
]
int
}
func
(
c
*
schedulerTestGatewayCache
)
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
{
if
id
,
ok
:=
c
.
sessionBindings
[
sessionHash
];
ok
{
return
id
,
nil
}
return
0
,
errors
.
New
(
"not found"
)
}
func
(
c
*
schedulerTestGatewayCache
)
SetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
if
c
.
sessionBindings
==
nil
{
c
.
sessionBindings
=
make
(
map
[
string
]
int64
)
}
c
.
sessionBindings
[
sessionHash
]
=
accountID
return
nil
}
func
(
c
*
schedulerTestGatewayCache
)
RefreshSessionTTL
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
c
*
schedulerTestGatewayCache
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
if
c
.
sessionBindings
==
nil
{
return
nil
}
if
c
.
deletedSessions
==
nil
{
c
.
deletedSessions
=
make
(
map
[
string
]
int
)
}
c
.
deletedSessions
[
sessionHash
]
++
delete
(
c
.
sessionBindings
,
sessionHash
)
return
nil
}
func
newSchedulerTestOpenAIWSV2Config
()
*
config
.
Config
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
StickyResponseIDTTLSeconds
=
3600
return
cfg
}
type
openAIAdvancedSchedulerSettingRepoStub
struct
{
values
map
[
string
]
string
}
func
(
s
*
openAIAdvancedSchedulerSettingRepoStub
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
Setting
,
error
)
{
value
,
err
:=
s
.
GetValue
(
ctx
,
key
)
if
err
!=
nil
{
return
nil
,
err
}
return
&
Setting
{
Key
:
key
,
Value
:
value
},
nil
}
func
(
s
*
openAIAdvancedSchedulerSettingRepoStub
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
if
s
==
nil
||
s
.
values
==
nil
{
return
""
,
ErrSettingNotFound
}
value
,
ok
:=
s
.
values
[
key
]
if
!
ok
{
return
""
,
ErrSettingNotFound
}
return
value
,
nil
}
func
(
s
*
openAIAdvancedSchedulerSettingRepoStub
)
Set
(
context
.
Context
,
string
,
string
)
error
{
panic
(
"unexpected call to Set"
)
}
func
(
s
*
openAIAdvancedSchedulerSettingRepoStub
)
GetMultiple
(
context
.
Context
,
[]
string
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected call to GetMultiple"
)
}
func
(
s
*
openAIAdvancedSchedulerSettingRepoStub
)
SetMultiple
(
context
.
Context
,
map
[
string
]
string
)
error
{
panic
(
"unexpected call to SetMultiple"
)
}
func
(
s
*
openAIAdvancedSchedulerSettingRepoStub
)
GetAll
(
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected call to GetAll"
)
}
func
(
s
*
openAIAdvancedSchedulerSettingRepoStub
)
Delete
(
context
.
Context
,
string
)
error
{
panic
(
"unexpected call to Delete"
)
}
func
newOpenAIAdvancedSchedulerRateLimitService
(
enabled
string
)
*
RateLimitService
{
resetOpenAIAdvancedSchedulerSettingCacheForTest
()
repo
:=
&
openAIAdvancedSchedulerSettingRepoStub
{
values
:
map
[
string
]
string
{},
}
if
enabled
!=
""
{
repo
.
values
[
openAIAdvancedSchedulerSettingKey
]
=
enabled
}
return
&
RateLimitService
{
settingService
:
NewSettingService
(
repo
,
&
config
.
Config
{}),
}
}
func
(
s
*
openAISnapshotCacheStub
)
GetSnapshot
(
ctx
context
.
Context
,
bucket
SchedulerBucket
)
([]
*
Account
,
bool
,
error
)
{
if
len
(
s
.
snapshotAccounts
)
==
0
{
return
nil
,
false
,
nil
...
...
@@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return
&
cloned
,
nil
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness
(
t
*
testing
.
T
)
{
resetOpenAIAdvancedSchedulerSettingCacheForTest
()
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10106
)
accounts
:=
[]
Account
{
{
ID
:
36001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
,
},
{
ID
:
36002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
cache
:=
&
schedulerTestGatewayCache
{}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
schedulerTestOpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
store
:=
svc
.
getOpenAIWSStateStore
()
require
.
NoError
(
t
,
store
.
BindResponseAccount
(
ctx
,
groupID
,
"resp_disabled_001"
,
36001
,
time
.
Hour
))
require
.
False
(
t
,
svc
.
isOpenAIAdvancedSchedulerEnabled
(
ctx
))
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
"resp_disabled_001"
,
""
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
,
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
selection
)
require
.
NotNil
(
t
,
selection
.
Account
)
require
.
Equal
(
t
,
int64
(
36002
),
selection
.
Account
.
ID
)
require
.
Equal
(
t
,
openAIAccountScheduleLayerLoadBalance
,
decision
.
Layer
)
require
.
False
(
t
,
decision
.
StickyPreviousHit
)
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting
(
t
*
testing
.
T
)
{
resetOpenAIAdvancedSchedulerSettingCacheForTest
()
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10107
)
accounts
:=
[]
Account
{
{
ID
:
37001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
},
{
ID
:
37002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
StickyResponseIDTTLSeconds
=
3600
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
schedulerTestOpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
&
schedulerTestGatewayCache
{},
cfg
:
cfg
,
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
store
:=
svc
.
getOpenAIWSStateStore
()
require
.
NoError
(
t
,
store
.
BindResponseAccount
(
ctx
,
groupID
,
"resp_enabled_001"
,
37001
,
time
.
Hour
))
require
.
True
(
t
,
svc
.
isOpenAIAdvancedSchedulerEnabled
(
ctx
))
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
"resp_enabled_001"
,
""
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
,
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
selection
)
require
.
NotNil
(
t
,
selection
.
Account
)
require
.
Equal
(
t
,
int64
(
37001
),
selection
.
Account
.
ID
)
require
.
Equal
(
t
,
openAIAccountScheduleLayerPreviousResponse
,
decision
.
Layer
)
require
.
True
(
t
,
decision
.
StickyPreviousHit
)
}
func
TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp
(
t
*
testing
.
T
)
{
resetOpenAIAdvancedSchedulerSettingCacheForTest
()
svc
:=
&
OpenAIGatewayService
{}
ttft
:=
120
svc
.
ReportOpenAIAccountScheduleResult
(
10
,
true
,
&
ttft
)
svc
.
RecordOpenAIAccountSwitch
()
snapshot
:=
svc
.
SnapshotOpenAIAccountSchedulerMetrics
()
require
.
Equal
(
t
,
OpenAIAccountSchedulerMetricsSnapshot
{},
snapshot
)
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10101
)
...
...
@@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup
:=
&
Account
{
ID
:
31002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
freshSticky
:=
&
Account
{
ID
:
31001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
freshBackup
:=
&
Account
{
ID
:
31002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
cache
:=
&
s
tub
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_rate_limited"
:
31001
}}
cache
:=
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_rate_limited"
:
31001
}}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
staleSticky
,
staleBackup
},
accountsByID
:
map
[
int64
]
*
Account
{
31001
:
freshSticky
,
31002
:
freshBackup
}}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
*
freshSticky
,
*
freshBackup
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{})}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
schedulerTestOpenAIAccountRepo
{
accounts
:
[]
Account
{
*
freshSticky
,
*
freshBackup
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
schedulerSnapshot
:
snapshotService
,
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
""
,
"session_hash_rate_limited"
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
)
require
.
NoError
(
t
,
err
)
...
...
@@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary
:=
&
Account
{
ID
:
32002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
stalePrimary
,
staleSecondary
},
accountsByID
:
map
[
int64
]
*
Account
{
32001
:
freshPrimary
,
32002
:
freshSecondary
}}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
*
freshPrimary
,
*
freshSecondary
}},
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
schedulerTestOpenAIAccountRepo
{
accounts
:
[]
Account
{
*
freshPrimary
,
*
freshSecondary
}},
cfg
:
&
config
.
Config
{},
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
schedulerSnapshot
:
snapshotService
,
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gpt-5.1"
,
nil
)
require
.
NoError
(
t
,
err
)
...
...
@@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup
:=
&
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
dbSticky
:=
Account
{
ID
:
33001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
dbBackup
:=
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
cache
:=
&
s
tub
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_db_runtime_recheck"
:
33001
}}
cache
:=
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_db_runtime_recheck"
:
33001
}}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
staleSticky
,
staleBackup
},
accountsByID
:
map
[
int64
]
*
Account
{
33001
:
staleSticky
,
33002
:
staleBackup
},
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
[]
Account
{
dbSticky
,
dbBackup
}},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
[]
Account
{
dbSticky
,
dbBackup
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
schedulerSnapshot
:
snapshotService
,
concurrencyService
:
NewConcurrencyService
(
s
tub
ConcurrencyCache
{}),
concurrencyService
:
NewConcurrencyService
(
s
chedulerTest
ConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
""
,
"session_hash_db_runtime_recheck"
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
)
...
...
@@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
[]
Account
{
dbPrimary
,
dbSecondary
}},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
[]
Account
{
dbPrimary
,
dbSecondary
}},
cfg
:
&
config
.
Config
{},
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
schedulerSnapshot
:
snapshotService
,
}
...
...
@@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
cache
:=
&
s
tub
GatewayCache
{}
cache
:=
&
s
chedulerTest
GatewayCache
{}
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
...
...
@@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg
.
Gateway
.
OpenAIWS
.
StickyResponseIDTTLSeconds
=
3600
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
store
:=
svc
.
getOpenAIWSStateStore
()
...
...
@@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable
:
true
,
Concurrency
:
1
,
}
cache
:=
&
s
tub
GatewayCache
{
cache
:=
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_abc"
:
account
.
ID
,
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
...
...
@@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority
:
9
,
},
}
cache
:=
&
s
tub
GatewayCache
{
cache
:=
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_sticky_busy"
:
21001
,
},
...
...
@@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
concurrencyCache
:=
s
tub
ConcurrencyCache
{
concurrencyCache
:=
s
chedulerTest
ConcurrencyCache
{
acquireResults
:
map
[
int64
]
bool
{
21001
:
false
,
// sticky 账号已满
21002
:
true
,
// 若回退负载均衡会命中该账号(本测试要求不能切换)
...
...
@@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
accounts
},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
cache
,
cfg
:
cfg
,
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
...
...
@@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http"
:
true
,
},
}
cache
:=
&
s
tub
GatewayCache
{
cache
:=
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_force_http"
:
account
.
ID
,
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
...
...
@@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
cache
:=
&
s
tub
GatewayCache
{
cache
:=
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_ws_only"
:
2201
,
},
}
cfg
:=
newOpenAIWSV2
Test
Config
()
cfg
:=
new
SchedulerTest
OpenAIWSV2Config
()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
concurrencyCache
:=
s
tub
ConcurrencyCache
{
concurrencyCache
:=
s
chedulerTest
ConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
2201
:
{
AccountID
:
2201
,
LoadRate
:
0
,
WaitingCount
:
0
},
2202
:
{
AccountID
:
2202
,
LoadRate
:
90
,
WaitingCount
:
5
},
...
...
@@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
accounts
},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
cache
,
cfg
:
cfg
,
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
...
...
@@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
&
stubGatewayCache
{},
cfg
:
newOpenAIWSV2TestConfig
(),
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
accountRepo
:
schedulerTestOpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
&
schedulerTestGatewayCache
{},
cfg
:
newSchedulerTestOpenAIWSV2Config
(),
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
...
...
@@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg
.
Gateway
.
OpenAIWS
.
SchedulerScoreWeights
.
ErrorRate
=
0.2
cfg
.
Gateway
.
OpenAIWS
.
SchedulerScoreWeights
.
TTFT
=
0.1
concurrencyCache
:=
s
tub
ConcurrencyCache
{
concurrencyCache
:=
s
chedulerTest
ConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
3001
:
{
AccountID
:
3001
,
LoadRate
:
95
,
WaitingCount
:
8
},
3002
:
{
AccountID
:
3002
,
LoadRate
:
20
,
WaitingCount
:
1
},
...
...
@@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
&
s
tub
GatewayCache
{},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
&
s
chedulerTest
GatewayCache
{},
cfg
:
cfg
,
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
...
...
@@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable
:
true
,
Concurrency
:
1
,
}
cache
:=
&
s
tub
GatewayCache
{
cache
:=
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_metrics"
:
account
.
ID
,
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
schedulerTestConcurrencyCache
{}),
}
selection
,
_
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
""
,
"session_hash_metrics"
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
)
...
...
@@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg
.
Gateway
.
OpenAIWS
.
SchedulerScoreWeights
.
ErrorRate
=
1
cfg
.
Gateway
.
OpenAIWS
.
SchedulerScoreWeights
.
TTFT
=
1
concurrencyCache
:=
s
tub
ConcurrencyCache
{
concurrencyCache
:=
s
chedulerTest
ConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
5101
:
{
AccountID
:
5101
,
LoadRate
:
20
,
WaitingCount
:
1
},
5102
:
{
AccountID
:
5102
,
LoadRate
:
20
,
WaitingCount
:
1
},
...
...
@@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
&
s
tub
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{}},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
accounts
},
cache
:
&
s
chedulerTest
GatewayCache
{
sessionBindings
:
map
[
string
]
int64
{}},
cfg
:
cfg
,
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
...
...
@@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func
TestOpenAIGatewayService_SchedulerWrappersAndDefaults
(
t
*
testing
.
T
)
{
resetOpenAIAdvancedSchedulerSettingCacheForTest
()
svc
:=
&
OpenAIGatewayService
{}
ttft
:=
120
svc
.
ReportOpenAIAccountScheduleResult
(
10
,
true
,
&
ttft
)
svc
.
RecordOpenAIAccountSwitch
()
snapshot
:=
svc
.
SnapshotOpenAIAccountSchedulerMetrics
()
require
.
GreaterOrEqual
(
t
,
snapshot
.
AccountSwitchTotal
,
int64
(
1
)
)
require
.
Equal
(
t
,
OpenAIAccountSchedulerMetricsSnapshot
{},
snapshot
)
require
.
Equal
(
t
,
7
,
svc
.
openAIWSLBTopK
())
require
.
Equal
(
t
,
openaiStickySessionTTL
,
svc
.
openAIWSSessionStickyTTL
())
...
...
@@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require
.
True
(
t
,
scheduler
.
isAccountTransportCompatible
(
nil
,
OpenAIUpstreamTransportHTTPSSE
))
require
.
False
(
t
,
scheduler
.
isAccountTransportCompatible
(
nil
,
OpenAIUpstreamTransportResponsesWebsocketV2
))
cfg
:=
newOpenAIWSV2
Test
Config
()
cfg
:=
new
SchedulerTest
OpenAIWSV2Config
()
scheduler
.
service
=
&
OpenAIGatewayService
{
cfg
:
cfg
}
account
:=
&
Account
{
ID
:
8801
,
...
...
backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
View file @
e9de839d
...
...
@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
OpenAIWSIngressModeCtxPool
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
s
tub
OpenAIAccountRepo
{
accounts
:
[]
Account
{
*
account
}},
cache
:
&
s
tub
GatewayCache
{},
accountRepo
:
s
chedulerTest
OpenAIAccountRepo
{
accounts
:
[]
Account
{
*
account
}},
cache
:
&
s
chedulerTest
GatewayCache
{},
cfg
:
cfg
,
rateLimitService
:
newOpenAIAdvancedSchedulerRateLimitService
(
"true"
),
schedulerSnapshot
:
&
SchedulerSnapshotService
{
cache
:
snapshotCache
},
concurrencyService
:
NewConcurrencyService
(
s
tub
ConcurrencyCache
{}),
concurrencyService
:
NewConcurrencyService
(
s
chedulerTest
ConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
...
...
backend/internal/service/payment_config_service.go
View file @
e9de839d
...
...
@@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL
,
SettingHelpText
,
SettingCancelRateLimitOn
,
SettingCancelRateLimitMax
,
SettingCancelWindowSize
,
SettingCancelWindowUnit
,
SettingCancelWindowMode
,
SettingPaymentVisibleMethodAlipayEnabled
,
SettingPaymentVisibleMethodAlipaySource
,
SettingPaymentVisibleMethodWxpayEnabled
,
SettingPaymentVisibleMethodWxpaySource
,
}
vals
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get payment config settings: %w"
,
err
)
}
cfg
:=
s
.
parsePaymentConfig
(
vals
)
if
s
.
entClient
!=
nil
{
instances
,
err
:=
s
.
entClient
.
PaymentProviderInstance
.
Query
()
.
Where
(
paymentproviderinstance
.
EnabledEQ
(
true
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list enabled provider instances: %w"
,
err
)
}
cfg
.
EnabledTypes
=
applyVisibleMethodRoutingToEnabledTypes
(
cfg
.
EnabledTypes
,
vals
,
buildVisibleMethodSourceAvailability
(
instances
))
}
else
{
cfg
.
EnabledTypes
=
applyVisibleMethodRoutingToEnabledTypes
(
cfg
.
EnabledTypes
,
vals
,
nil
)
}
// Load Stripe publishable key from the first enabled Stripe provider instance
cfg
.
StripePublishableKey
=
s
.
getStripePublishableKey
(
ctx
)
return
cfg
,
nil
...
...
@@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg
.
LoadBalanceStrategy
=
payment
.
DefaultLoadBalanceStrategy
}
if
raw
:=
vals
[
SettingEnabledPaymentTypes
];
raw
!=
""
{
types
:=
make
([]
string
,
0
,
len
(
strings
.
Split
(
raw
,
","
)))
for
_
,
t
:=
range
strings
.
Split
(
raw
,
","
)
{
t
=
strings
.
TrimSpace
(
t
)
if
t
!=
""
{
cfg
.
EnabledT
ypes
=
append
(
cfg
.
EnabledT
ypes
,
t
)
t
ypes
=
append
(
t
ypes
,
t
)
}
}
cfg
.
EnabledTypes
=
NormalizeVisibleMethods
(
types
)
}
return
cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func
(
s
*
PaymentConfigService
)
getStripePublishableKey
(
ctx
context
.
Context
)
string
{
if
s
.
entClient
==
nil
{
return
""
}
instances
,
err
:=
s
.
entClient
.
PaymentProviderInstance
.
Query
()
.
Where
(
paymentproviderinstance
.
EnabledEQ
(
true
),
...
...
@@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return
v
}
func
buildVisibleMethodSourceAvailability
(
instances
[]
*
dbent
.
PaymentProviderInstance
)
map
[
string
]
bool
{
available
:=
make
(
map
[
string
]
bool
,
4
)
for
_
,
inst
:=
range
instances
{
switch
inst
.
ProviderKey
{
case
payment
.
TypeAlipay
:
if
inst
.
SupportedTypes
==
""
||
payment
.
InstanceSupportsType
(
inst
.
SupportedTypes
,
payment
.
TypeAlipay
)
||
payment
.
InstanceSupportsType
(
inst
.
SupportedTypes
,
payment
.
TypeAlipayDirect
)
{
available
[
VisibleMethodSourceOfficialAlipay
]
=
true
}
case
payment
.
TypeWxpay
:
if
inst
.
SupportedTypes
==
""
||
payment
.
InstanceSupportsType
(
inst
.
SupportedTypes
,
payment
.
TypeWxpay
)
||
payment
.
InstanceSupportsType
(
inst
.
SupportedTypes
,
payment
.
TypeWxpayDirect
)
{
available
[
VisibleMethodSourceOfficialWechat
]
=
true
}
case
payment
.
TypeEasyPay
:
for
_
,
supportedType
:=
range
splitTypes
(
inst
.
SupportedTypes
)
{
switch
NormalizeVisibleMethod
(
supportedType
)
{
case
payment
.
TypeAlipay
:
available
[
VisibleMethodSourceEasyPayAlipay
]
=
true
case
payment
.
TypeWxpay
:
available
[
VisibleMethodSourceEasyPayWechat
]
=
true
}
}
}
}
return
available
}
func
applyVisibleMethodRoutingToEnabledTypes
(
base
[]
string
,
vals
map
[
string
]
string
,
available
map
[
string
]
bool
)
[]
string
{
shouldExpose
:=
map
[
string
]
bool
{
payment
.
TypeAlipay
:
visibleMethodShouldBeExposed
(
payment
.
TypeAlipay
,
vals
,
available
),
payment
.
TypeWxpay
:
visibleMethodShouldBeExposed
(
payment
.
TypeWxpay
,
vals
,
available
),
}
seen
:=
make
(
map
[
string
]
struct
{},
len
(
base
)
+
2
)
out
:=
make
([]
string
,
0
,
len
(
base
)
+
2
)
appendType
:=
func
(
paymentType
string
)
{
paymentType
=
NormalizeVisibleMethod
(
paymentType
)
if
paymentType
==
""
{
return
}
if
_
,
ok
:=
seen
[
paymentType
];
ok
{
return
}
seen
[
paymentType
]
=
struct
{}{}
out
=
append
(
out
,
paymentType
)
}
for
_
,
paymentType
:=
range
base
{
visibleMethod
:=
NormalizeVisibleMethod
(
paymentType
)
switch
visibleMethod
{
case
payment
.
TypeAlipay
,
payment
.
TypeWxpay
:
if
shouldExpose
[
visibleMethod
]
{
appendType
(
visibleMethod
)
}
default
:
appendType
(
visibleMethod
)
}
}
for
_
,
visibleMethod
:=
range
[]
string
{
payment
.
TypeAlipay
,
payment
.
TypeWxpay
}
{
if
shouldExpose
[
visibleMethod
]
{
appendType
(
visibleMethod
)
}
}
return
out
}
func
visibleMethodShouldBeExposed
(
method
string
,
vals
map
[
string
]
string
,
available
map
[
string
]
bool
)
bool
{
enabledKey
:=
visibleMethodEnabledSettingKey
(
method
)
sourceKey
:=
visibleMethodSourceSettingKey
(
method
)
if
enabledKey
==
""
||
sourceKey
==
""
||
vals
[
enabledKey
]
!=
"true"
{
return
false
}
source
:=
NormalizeVisibleMethodSource
(
method
,
vals
[
sourceKey
])
return
source
!=
""
&&
available
[
source
]
}
backend/internal/service/payment_config_service_test.go
View file @
e9de839d
package
service
import
(
"context"
"database/sql"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"entgo.io/ent/dialect"
entsql
"entgo.io/ent/dialect/sql"
_
"modernc.org/sqlite"
)
func
TestPcParseFloat
(
t
*
testing
.
T
)
{
...
...
@@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
t
.
Run
(
"enabled types are normalized to visible methods and deduplicated"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
vals
:=
map
[
string
]
string
{
SettingEnabledPaymentTypes
:
"alipay_direct, alipay, wxpay_direct, wxpay"
,
}
cfg
:=
svc
.
parsePaymentConfig
(
vals
)
if
len
(
cfg
.
EnabledTypes
)
!=
2
{
t
.
Fatalf
(
"EnabledTypes len = %d, want 2"
,
len
(
cfg
.
EnabledTypes
))
}
if
cfg
.
EnabledTypes
[
0
]
!=
"alipay"
||
cfg
.
EnabledTypes
[
1
]
!=
"wxpay"
{
t
.
Fatalf
(
"EnabledTypes = %v, want [alipay wxpay]"
,
cfg
.
EnabledTypes
)
}
})
t
.
Run
(
"empty enabled types string"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
vals
:=
map
[
string
]
string
{
...
...
@@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
func
TestApplyVisibleMethodRoutingToEnabledTypes
(
t
*
testing
.
T
)
{
t
.
Parallel
()
base
:=
[]
string
{
"alipay"
,
"wxpay"
,
"stripe"
}
vals
:=
map
[
string
]
string
{
SettingPaymentVisibleMethodAlipayEnabled
:
"true"
,
SettingPaymentVisibleMethodAlipaySource
:
VisibleMethodSourceOfficialAlipay
,
SettingPaymentVisibleMethodWxpayEnabled
:
"true"
,
SettingPaymentVisibleMethodWxpaySource
:
VisibleMethodSourceOfficialWechat
,
}
available
:=
map
[
string
]
bool
{
VisibleMethodSourceOfficialAlipay
:
true
,
VisibleMethodSourceOfficialWechat
:
false
,
}
got
:=
applyVisibleMethodRoutingToEnabledTypes
(
base
,
vals
,
available
)
want
:=
[]
string
{
"alipay"
,
"stripe"
}
if
len
(
got
)
!=
len
(
want
)
{
t
.
Fatalf
(
"applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)"
,
len
(
got
),
len
(
want
),
got
)
}
for
i
:=
range
want
{
if
got
[
i
]
!=
want
[
i
]
{
t
.
Fatalf
(
"applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)"
,
i
,
got
[
i
],
want
[
i
],
got
)
}
}
}
func
TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod
(
t
*
testing
.
T
)
{
t
.
Parallel
()
base
:=
[]
string
{
"stripe"
}
vals
:=
map
[
string
]
string
{
SettingPaymentVisibleMethodAlipayEnabled
:
"true"
,
SettingPaymentVisibleMethodAlipaySource
:
VisibleMethodSourceEasyPayAlipay
,
}
available
:=
map
[
string
]
bool
{
VisibleMethodSourceEasyPayAlipay
:
true
,
}
got
:=
applyVisibleMethodRoutingToEnabledTypes
(
base
,
vals
,
available
)
want
:=
[]
string
{
"stripe"
,
"alipay"
}
if
len
(
got
)
!=
len
(
want
)
{
t
.
Fatalf
(
"applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)"
,
len
(
got
),
len
(
want
),
got
)
}
for
i
:=
range
want
{
if
got
[
i
]
!=
want
[
i
]
{
t
.
Fatalf
(
"applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)"
,
i
,
got
[
i
],
want
[
i
],
got
)
}
}
}
func
TestBuildVisibleMethodSourceAvailability
(
t
*
testing
.
T
)
{
t
.
Parallel
()
instances
:=
[]
*
dbent
.
PaymentProviderInstance
{
{
ProviderKey
:
payment
.
TypeAlipay
,
SupportedTypes
:
"alipay"
},
{
ProviderKey
:
payment
.
TypeEasyPay
,
SupportedTypes
:
"wxpay_direct, alipay"
},
{
ProviderKey
:
payment
.
TypeWxpay
,
SupportedTypes
:
"wxpay_direct"
},
}
got
:=
buildVisibleMethodSourceAvailability
(
instances
)
if
!
got
[
VisibleMethodSourceOfficialAlipay
]
{
t
.
Fatalf
(
"expected %q to be available"
,
VisibleMethodSourceOfficialAlipay
)
}
if
!
got
[
VisibleMethodSourceEasyPayAlipay
]
{
t
.
Fatalf
(
"expected %q to be available"
,
VisibleMethodSourceEasyPayAlipay
)
}
if
!
got
[
VisibleMethodSourceOfficialWechat
]
{
t
.
Fatalf
(
"expected %q to be available"
,
VisibleMethodSourceOfficialWechat
)
}
if
!
got
[
VisibleMethodSourceEasyPayWechat
]
{
t
.
Fatalf
(
"expected %q to be available"
,
VisibleMethodSourceEasyPayWechat
)
}
}
func
TestGetPaymentConfigAppliesVisibleMethodRouting
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
_
,
err
:=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetName
(
"EasyPay Alipay"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"alipay"
)
.
SetEnabled
(
true
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create easypay instance: %v"
,
err
)
}
svc
:=
&
PaymentConfigService
{
entClient
:
client
,
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
SettingEnabledPaymentTypes
:
"alipay,wxpay,stripe"
,
SettingPaymentVisibleMethodAlipayEnabled
:
"true"
,
SettingPaymentVisibleMethodAlipaySource
:
"easypay"
,
SettingPaymentVisibleMethodWxpayEnabled
:
"true"
,
SettingPaymentVisibleMethodWxpaySource
:
"wxpay"
,
},
},
}
cfg
,
err
:=
svc
.
GetPaymentConfig
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"GetPaymentConfig returned error: %v"
,
err
)
}
want
:=
[]
string
{
payment
.
TypeAlipay
,
payment
.
TypeStripe
}
if
len
(
cfg
.
EnabledTypes
)
!=
len
(
want
)
{
t
.
Fatalf
(
"EnabledTypes len = %d, want %d (%v)"
,
len
(
cfg
.
EnabledTypes
),
len
(
want
),
cfg
.
EnabledTypes
)
}
for
i
:=
range
want
{
if
cfg
.
EnabledTypes
[
i
]
!=
want
[
i
]
{
t
.
Fatalf
(
"EnabledTypes[%d] = %q, want %q (full=%v)"
,
i
,
cfg
.
EnabledTypes
[
i
],
want
[
i
],
cfg
.
EnabledTypes
)
}
}
}
func
newPaymentConfigServiceTestClient
(
t
*
testing
.
T
)
*
dbent
.
Client
{
t
.
Helper
()
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:payment_config_service?mode=memory&cache=shared"
)
if
err
!=
nil
{
t
.
Fatalf
(
"open sqlite: %v"
,
err
)
}
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
if
_
,
err
:=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
);
err
!=
nil
{
t
.
Fatalf
(
"enable foreign keys: %v"
,
err
)
}
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
return
client
}
type
paymentConfigSettingRepoStub
struct
{
values
map
[
string
]
string
}
func
(
s
*
paymentConfigSettingRepoStub
)
Get
(
context
.
Context
,
string
)
(
*
Setting
,
error
)
{
return
nil
,
nil
}
func
(
s
*
paymentConfigSettingRepoStub
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
return
s
.
values
[
key
],
nil
}
func
(
s
*
paymentConfigSettingRepoStub
)
Set
(
context
.
Context
,
string
,
string
)
error
{
return
nil
}
func
(
s
*
paymentConfigSettingRepoStub
)
GetMultiple
(
_
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
out
:=
make
(
map
[
string
]
string
,
len
(
keys
))
for
_
,
key
:=
range
keys
{
out
[
key
]
=
s
.
values
[
key
]
}
return
out
,
nil
}
func
(
s
*
paymentConfigSettingRepoStub
)
SetMultiple
(
context
.
Context
,
map
[
string
]
string
)
error
{
return
nil
}
func
(
s
*
paymentConfigSettingRepoStub
)
GetAll
(
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
return
s
.
values
,
nil
}
func
(
s
*
paymentConfigSettingRepoStub
)
Delete
(
context
.
Context
,
string
)
error
{
return
nil
}
backend/internal/service/payment_resume_service.go
0 → 100644
View file @
e9de839d
package
service
import
(
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const
(
PaymentSourceHostedRedirect
=
"hosted_redirect"
PaymentSourceWechatInAppResume
=
"wechat_in_app_resume"
paymentResumeFallbackSigningKey
=
"sub2api-payment-resume"
SettingPaymentVisibleMethodAlipaySource
=
"payment_visible_method_alipay_source"
SettingPaymentVisibleMethodWxpaySource
=
"payment_visible_method_wxpay_source"
SettingPaymentVisibleMethodAlipayEnabled
=
"payment_visible_method_alipay_enabled"
SettingPaymentVisibleMethodWxpayEnabled
=
"payment_visible_method_wxpay_enabled"
VisibleMethodSourceOfficialAlipay
=
"official_alipay"
VisibleMethodSourceEasyPayAlipay
=
"easypay_alipay"
VisibleMethodSourceOfficialWechat
=
"official_wxpay"
VisibleMethodSourceEasyPayWechat
=
"easypay_wxpay"
)
type
ResumeTokenClaims
struct
{
OrderID
int64
`json:"oid"`
UserID
int64
`json:"uid,omitempty"`
ProviderInstanceID
string
`json:"pi,omitempty"`
ProviderKey
string
`json:"pk,omitempty"`
PaymentType
string
`json:"pt,omitempty"`
CanonicalReturnURL
string
`json:"ru,omitempty"`
IssuedAt
int64
`json:"iat"`
}
type
PaymentResumeService
struct
{
signingKey
[]
byte
}
type
visibleMethodLoadBalancer
struct
{
inner
payment
.
LoadBalancer
configService
*
PaymentConfigService
}
func
NewPaymentResumeService
(
signingKey
[]
byte
)
*
PaymentResumeService
{
return
&
PaymentResumeService
{
signingKey
:
signingKey
}
}
func
NormalizeVisibleMethod
(
method
string
)
string
{
return
payment
.
GetBasePaymentType
(
strings
.
TrimSpace
(
method
))
}
func
NormalizeVisibleMethods
(
methods
[]
string
)
[]
string
{
if
len
(
methods
)
==
0
{
return
nil
}
seen
:=
make
(
map
[
string
]
struct
{},
len
(
methods
))
out
:=
make
([]
string
,
0
,
len
(
methods
))
for
_
,
method
:=
range
methods
{
normalized
:=
NormalizeVisibleMethod
(
method
)
if
normalized
==
""
{
continue
}
if
_
,
ok
:=
seen
[
normalized
];
ok
{
continue
}
seen
[
normalized
]
=
struct
{}{}
out
=
append
(
out
,
normalized
)
}
return
out
}
func
NormalizePaymentSource
(
source
string
)
string
{
switch
strings
.
TrimSpace
(
strings
.
ToLower
(
source
))
{
case
""
,
PaymentSourceHostedRedirect
:
return
PaymentSourceHostedRedirect
case
"wechat_in_app"
,
"wxpay_resume"
,
PaymentSourceWechatInAppResume
:
return
PaymentSourceWechatInAppResume
default
:
return
strings
.
TrimSpace
(
strings
.
ToLower
(
source
))
}
}
func
NormalizeVisibleMethodSource
(
method
,
source
string
)
string
{
switch
NormalizeVisibleMethod
(
method
)
{
case
payment
.
TypeAlipay
:
switch
strings
.
TrimSpace
(
strings
.
ToLower
(
source
))
{
case
VisibleMethodSourceOfficialAlipay
,
payment
.
TypeAlipay
,
payment
.
TypeAlipayDirect
,
"official"
:
return
VisibleMethodSourceOfficialAlipay
case
VisibleMethodSourceEasyPayAlipay
,
payment
.
TypeEasyPay
:
return
VisibleMethodSourceEasyPayAlipay
}
case
payment
.
TypeWxpay
:
switch
strings
.
TrimSpace
(
strings
.
ToLower
(
source
))
{
case
VisibleMethodSourceOfficialWechat
,
payment
.
TypeWxpay
,
payment
.
TypeWxpayDirect
,
"wechat"
,
"official"
:
return
VisibleMethodSourceOfficialWechat
case
VisibleMethodSourceEasyPayWechat
,
payment
.
TypeEasyPay
:
return
VisibleMethodSourceEasyPayWechat
}
}
return
""
}
func
VisibleMethodProviderKeyForSource
(
method
,
source
string
)
(
string
,
bool
)
{
switch
NormalizeVisibleMethodSource
(
method
,
source
)
{
case
VisibleMethodSourceOfficialAlipay
:
return
payment
.
TypeAlipay
,
NormalizeVisibleMethod
(
method
)
==
payment
.
TypeAlipay
case
VisibleMethodSourceEasyPayAlipay
:
return
payment
.
TypeEasyPay
,
NormalizeVisibleMethod
(
method
)
==
payment
.
TypeAlipay
case
VisibleMethodSourceOfficialWechat
:
return
payment
.
TypeWxpay
,
NormalizeVisibleMethod
(
method
)
==
payment
.
TypeWxpay
case
VisibleMethodSourceEasyPayWechat
:
return
payment
.
TypeEasyPay
,
NormalizeVisibleMethod
(
method
)
==
payment
.
TypeWxpay
default
:
return
""
,
false
}
}
func
newVisibleMethodLoadBalancer
(
inner
payment
.
LoadBalancer
,
configService
*
PaymentConfigService
)
payment
.
LoadBalancer
{
if
inner
==
nil
||
configService
==
nil
||
configService
.
settingRepo
==
nil
{
return
inner
}
return
&
visibleMethodLoadBalancer
{
inner
:
inner
,
configService
:
configService
}
}
func
(
lb
*
visibleMethodLoadBalancer
)
GetInstanceConfig
(
ctx
context
.
Context
,
instanceID
int64
)
(
map
[
string
]
string
,
error
)
{
return
lb
.
inner
.
GetInstanceConfig
(
ctx
,
instanceID
)
}
func
(
lb
*
visibleMethodLoadBalancer
)
SelectInstance
(
ctx
context
.
Context
,
providerKey
string
,
paymentType
payment
.
PaymentType
,
strategy
payment
.
Strategy
,
orderAmount
float64
)
(
*
payment
.
InstanceSelection
,
error
)
{
visibleMethod
:=
NormalizeVisibleMethod
(
paymentType
)
if
providerKey
!=
""
||
(
visibleMethod
!=
payment
.
TypeAlipay
&&
visibleMethod
!=
payment
.
TypeWxpay
)
{
return
lb
.
inner
.
SelectInstance
(
ctx
,
providerKey
,
paymentType
,
strategy
,
orderAmount
)
}
enabledKey
:=
visibleMethodEnabledSettingKey
(
visibleMethod
)
sourceKey
:=
visibleMethodSourceSettingKey
(
visibleMethod
)
vals
,
err
:=
lb
.
configService
.
settingRepo
.
GetMultiple
(
ctx
,
[]
string
{
enabledKey
,
sourceKey
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"load visible method routing for %s: %w"
,
visibleMethod
,
err
)
}
if
vals
[
enabledKey
]
!=
"true"
{
return
nil
,
fmt
.
Errorf
(
"visible payment method %s is disabled"
,
visibleMethod
)
}
targetProviderKey
,
ok
:=
VisibleMethodProviderKeyForSource
(
visibleMethod
,
vals
[
sourceKey
])
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"visible payment method %s has no valid source"
,
visibleMethod
)
}
return
lb
.
inner
.
SelectInstance
(
ctx
,
targetProviderKey
,
paymentType
,
strategy
,
orderAmount
)
}
func
visibleMethodEnabledSettingKey
(
method
string
)
string
{
switch
NormalizeVisibleMethod
(
method
)
{
case
payment
.
TypeAlipay
:
return
SettingPaymentVisibleMethodAlipayEnabled
case
payment
.
TypeWxpay
:
return
SettingPaymentVisibleMethodWxpayEnabled
default
:
return
""
}
}
func
visibleMethodSourceSettingKey
(
method
string
)
string
{
switch
NormalizeVisibleMethod
(
method
)
{
case
payment
.
TypeAlipay
:
return
SettingPaymentVisibleMethodAlipaySource
case
payment
.
TypeWxpay
:
return
SettingPaymentVisibleMethodWxpaySource
default
:
return
""
}
}
func
CanonicalizeReturnURL
(
raw
string
)
(
string
,
error
)
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
""
,
nil
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
||
!
parsed
.
IsAbs
()
||
parsed
.
Host
==
""
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_RETURN_URL"
,
"return_url must be an absolute http/https URL"
)
}
if
parsed
.
Scheme
!=
"http"
&&
parsed
.
Scheme
!=
"https"
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_RETURN_URL"
,
"return_url must use http or https"
)
}
parsed
.
Fragment
=
""
if
parsed
.
Path
==
""
{
parsed
.
Path
=
"/"
}
return
parsed
.
String
(),
nil
}
func
(
s
*
PaymentResumeService
)
CreateToken
(
claims
ResumeTokenClaims
)
(
string
,
error
)
{
if
claims
.
OrderID
<=
0
{
return
""
,
fmt
.
Errorf
(
"resume token requires order id"
)
}
if
claims
.
IssuedAt
==
0
{
claims
.
IssuedAt
=
time
.
Now
()
.
Unix
()
}
payload
,
err
:=
json
.
Marshal
(
claims
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"marshal resume claims: %w"
,
err
)
}
encodedPayload
:=
base64
.
RawURLEncoding
.
EncodeToString
(
payload
)
return
encodedPayload
+
"."
+
s
.
sign
(
encodedPayload
),
nil
}
func
(
s
*
PaymentResumeService
)
ParseToken
(
token
string
)
(
*
ResumeTokenClaims
,
error
)
{
parts
:=
strings
.
Split
(
token
,
"."
)
if
len
(
parts
)
!=
2
||
parts
[
0
]
==
""
||
parts
[
1
]
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token is malformed"
)
}
if
!
hmac
.
Equal
([]
byte
(
parts
[
1
]),
[]
byte
(
s
.
sign
(
parts
[
0
])))
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token signature mismatch"
)
}
payload
,
err
:=
base64
.
RawURLEncoding
.
DecodeString
(
parts
[
0
])
if
err
!=
nil
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token payload is malformed"
)
}
var
claims
ResumeTokenClaims
if
err
:=
json
.
Unmarshal
(
payload
,
&
claims
);
err
!=
nil
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token payload is invalid"
)
}
if
claims
.
OrderID
<=
0
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token missing order id"
)
}
return
&
claims
,
nil
}
func
(
s
*
PaymentResumeService
)
sign
(
payload
string
)
string
{
key
:=
s
.
signingKey
if
len
(
key
)
==
0
{
key
=
[]
byte
(
paymentResumeFallbackSigningKey
)
}
mac
:=
hmac
.
New
(
sha256
.
New
,
key
)
_
,
_
=
mac
.
Write
([]
byte
(
payload
))
return
base64
.
RawURLEncoding
.
EncodeToString
(
mac
.
Sum
(
nil
))
}
backend/internal/service/payment_resume_service_test.go
0 → 100644
View file @
e9de839d
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func
TestNormalizeVisibleMethods
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
NormalizeVisibleMethods
([]
string
{
"alipay_direct"
,
"alipay"
,
" wxpay_direct "
,
"wxpay"
,
"stripe"
,
})
want
:=
[]
string
{
"alipay"
,
"wxpay"
,
"stripe"
}
if
len
(
got
)
!=
len
(
want
)
{
t
.
Fatalf
(
"NormalizeVisibleMethods len = %d, want %d (%v)"
,
len
(
got
),
len
(
want
),
got
)
}
for
i
:=
range
want
{
if
got
[
i
]
!=
want
[
i
]
{
t
.
Fatalf
(
"NormalizeVisibleMethods[%d] = %q, want %q (full=%v)"
,
i
,
got
[
i
],
want
[
i
],
got
)
}
}
}
func
TestNormalizePaymentSource
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
input
string
expect
string
}{
{
name
:
"empty uses default"
,
input
:
""
,
expect
:
PaymentSourceHostedRedirect
},
{
name
:
"wechat alias normalized"
,
input
:
"wechat_in_app"
,
expect
:
PaymentSourceWechatInAppResume
},
{
name
:
"canonical value preserved"
,
input
:
PaymentSourceWechatInAppResume
,
expect
:
PaymentSourceWechatInAppResume
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
if
got
:=
NormalizePaymentSource
(
tt
.
input
);
got
!=
tt
.
expect
{
t
.
Fatalf
(
"NormalizePaymentSource(%q) = %q, want %q"
,
tt
.
input
,
got
,
tt
.
expect
)
}
})
}
}
func
TestCanonicalizeReturnURL
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
err
:=
CanonicalizeReturnURL
(
"https://example.com/pay/result?b=2#a"
)
if
err
!=
nil
{
t
.
Fatalf
(
"CanonicalizeReturnURL returned error: %v"
,
err
)
}
if
got
!=
"https://example.com/pay/result?b=2"
{
t
.
Fatalf
(
"CanonicalizeReturnURL = %q, want %q"
,
got
,
"https://example.com/pay/result?b=2"
)
}
}
func
TestCanonicalizeReturnURLRejectsRelativeURL
(
t
*
testing
.
T
)
{
t
.
Parallel
()
if
_
,
err
:=
CanonicalizeReturnURL
(
"/payment/result"
);
err
==
nil
{
t
.
Fatal
(
"CanonicalizeReturnURL should reject relative URLs"
)
}
}
func
TestPaymentResumeTokenRoundTrip
(
t
*
testing
.
T
)
{
t
.
Parallel
()
svc
:=
NewPaymentResumeService
([]
byte
(
"0123456789abcdef0123456789abcdef"
))
token
,
err
:=
svc
.
CreateToken
(
ResumeTokenClaims
{
OrderID
:
42
,
UserID
:
7
,
ProviderInstanceID
:
"19"
,
ProviderKey
:
"easypay"
,
PaymentType
:
"wxpay"
,
CanonicalReturnURL
:
"https://example.com/payment/result"
,
IssuedAt
:
1234567890
,
})
if
err
!=
nil
{
t
.
Fatalf
(
"CreateToken returned error: %v"
,
err
)
}
claims
,
err
:=
svc
.
ParseToken
(
token
)
if
err
!=
nil
{
t
.
Fatalf
(
"ParseToken returned error: %v"
,
err
)
}
if
claims
.
OrderID
!=
42
||
claims
.
UserID
!=
7
{
t
.
Fatalf
(
"claims mismatch: %+v"
,
claims
)
}
if
claims
.
ProviderInstanceID
!=
"19"
||
claims
.
ProviderKey
!=
"easypay"
||
claims
.
PaymentType
!=
"wxpay"
{
t
.
Fatalf
(
"claims provider snapshot mismatch: %+v"
,
claims
)
}
if
claims
.
CanonicalReturnURL
!=
"https://example.com/payment/result"
{
t
.
Fatalf
(
"claims return URL = %q"
,
claims
.
CanonicalReturnURL
)
}
}
func
TestNormalizeVisibleMethodSource
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
method
string
input
string
want
string
}{
{
name
:
"alipay official alias"
,
method
:
payment
.
TypeAlipay
,
input
:
"alipay"
,
want
:
VisibleMethodSourceOfficialAlipay
},
{
name
:
"alipay easypay alias"
,
method
:
payment
.
TypeAlipay
,
input
:
"easypay"
,
want
:
VisibleMethodSourceEasyPayAlipay
},
{
name
:
"wxpay official alias"
,
method
:
payment
.
TypeWxpay
,
input
:
"wxpay"
,
want
:
VisibleMethodSourceOfficialWechat
},
{
name
:
"wxpay easypay alias"
,
method
:
payment
.
TypeWxpay
,
input
:
"easypay"
,
want
:
VisibleMethodSourceEasyPayWechat
},
{
name
:
"unsupported source"
,
method
:
payment
.
TypeWxpay
,
input
:
"stripe"
,
want
:
""
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
if
got
:=
NormalizeVisibleMethodSource
(
tt
.
method
,
tt
.
input
);
got
!=
tt
.
want
{
t
.
Fatalf
(
"NormalizeVisibleMethodSource(%q, %q) = %q, want %q"
,
tt
.
method
,
tt
.
input
,
got
,
tt
.
want
)
}
})
}
}
func
TestVisibleMethodProviderKeyForSource
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
method
string
source
string
want
string
ok
bool
}{
{
name
:
"official alipay"
,
method
:
payment
.
TypeAlipay
,
source
:
VisibleMethodSourceOfficialAlipay
,
want
:
payment
.
TypeAlipay
,
ok
:
true
},
{
name
:
"easypay alipay"
,
method
:
payment
.
TypeAlipay
,
source
:
VisibleMethodSourceEasyPayAlipay
,
want
:
payment
.
TypeEasyPay
,
ok
:
true
},
{
name
:
"official wechat"
,
method
:
payment
.
TypeWxpay
,
source
:
VisibleMethodSourceOfficialWechat
,
want
:
payment
.
TypeWxpay
,
ok
:
true
},
{
name
:
"easypay wechat"
,
method
:
payment
.
TypeWxpay
,
source
:
VisibleMethodSourceEasyPayWechat
,
want
:
payment
.
TypeEasyPay
,
ok
:
true
},
{
name
:
"mismatched method and source"
,
method
:
payment
.
TypeAlipay
,
source
:
VisibleMethodSourceOfficialWechat
,
want
:
""
,
ok
:
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
ok
:=
VisibleMethodProviderKeyForSource
(
tt
.
method
,
tt
.
source
)
if
got
!=
tt
.
want
||
ok
!=
tt
.
ok
{
t
.
Fatalf
(
"VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)"
,
tt
.
method
,
tt
.
source
,
got
,
ok
,
tt
.
want
,
tt
.
ok
)
}
})
}
}
func
TestVisibleMethodLoadBalancerUsesConfiguredSource
(
t
*
testing
.
T
)
{
t
.
Parallel
()
inner
:=
&
captureLoadBalancer
{}
configService
:=
&
PaymentConfigService
{
settingRepo
:
&
paymentSettingRepoStub
{
values
:
map
[
string
]
string
{
SettingPaymentVisibleMethodAlipayEnabled
:
"true"
,
SettingPaymentVisibleMethodAlipaySource
:
VisibleMethodSourceOfficialAlipay
,
},
},
}
lb
:=
newVisibleMethodLoadBalancer
(
inner
,
configService
)
_
,
err
:=
lb
.
SelectInstance
(
context
.
Background
(),
""
,
payment
.
TypeAlipay
,
payment
.
StrategyRoundRobin
,
12.5
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectInstance returned error: %v"
,
err
)
}
if
inner
.
lastProviderKey
!=
payment
.
TypeAlipay
{
t
.
Fatalf
(
"lastProviderKey = %q, want %q"
,
inner
.
lastProviderKey
,
payment
.
TypeAlipay
)
}
}
func
TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod
(
t
*
testing
.
T
)
{
t
.
Parallel
()
inner
:=
&
captureLoadBalancer
{}
configService
:=
&
PaymentConfigService
{
settingRepo
:
&
paymentSettingRepoStub
{
values
:
map
[
string
]
string
{
SettingPaymentVisibleMethodWxpayEnabled
:
"false"
,
SettingPaymentVisibleMethodWxpaySource
:
VisibleMethodSourceOfficialWechat
,
},
},
}
lb
:=
newVisibleMethodLoadBalancer
(
inner
,
configService
)
if
_
,
err
:=
lb
.
SelectInstance
(
context
.
Background
(),
""
,
payment
.
TypeWxpay
,
payment
.
StrategyRoundRobin
,
9.9
);
err
==
nil
{
t
.
Fatal
(
"SelectInstance should reject disabled visible method"
)
}
}
type
paymentSettingRepoStub
struct
{
values
map
[
string
]
string
}
func
(
s
*
paymentSettingRepoStub
)
Get
(
context
.
Context
,
string
)
(
*
Setting
,
error
)
{
return
nil
,
nil
}
func
(
s
*
paymentSettingRepoStub
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
return
s
.
values
[
key
],
nil
}
func
(
s
*
paymentSettingRepoStub
)
Set
(
context
.
Context
,
string
,
string
)
error
{
return
nil
}
func
(
s
*
paymentSettingRepoStub
)
GetMultiple
(
_
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
out
:=
make
(
map
[
string
]
string
,
len
(
keys
))
for
_
,
key
:=
range
keys
{
out
[
key
]
=
s
.
values
[
key
]
}
return
out
,
nil
}
func
(
s
*
paymentSettingRepoStub
)
SetMultiple
(
context
.
Context
,
map
[
string
]
string
)
error
{
return
nil
}
func
(
s
*
paymentSettingRepoStub
)
GetAll
(
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
return
s
.
values
,
nil
}
func
(
s
*
paymentSettingRepoStub
)
Delete
(
context
.
Context
,
string
)
error
{
return
nil
}
type
captureLoadBalancer
struct
{
lastProviderKey
string
lastPaymentType
string
}
func
(
c
*
captureLoadBalancer
)
GetInstanceConfig
(
context
.
Context
,
int64
)
(
map
[
string
]
string
,
error
)
{
return
map
[
string
]
string
{},
nil
}
func
(
c
*
captureLoadBalancer
)
SelectInstance
(
_
context
.
Context
,
providerKey
string
,
paymentType
payment
.
PaymentType
,
_
payment
.
Strategy
,
_
float64
)
(
*
payment
.
InstanceSelection
,
error
)
{
c
.
lastProviderKey
=
providerKey
c
.
lastPaymentType
=
paymentType
return
&
payment
.
InstanceSelection
{
ProviderKey
:
providerKey
,
SupportedTypes
:
paymentType
},
nil
}
backend/internal/service/payment_service.go
View file @
e9de839d
...
...
@@ -65,15 +65,17 @@ func generateRandomString(n int) string {
}
type
CreateOrderRequest
struct
{
UserID
int64
Amount
float64
PaymentType
string
ClientIP
string
IsMobile
bool
SrcHost
string
SrcURL
string
OrderType
string
PlanID
int64
UserID
int64
Amount
float64
PaymentType
string
ClientIP
string
IsMobile
bool
SrcHost
string
SrcURL
string
ReturnURL
string
PaymentSource
string
OrderType
string
PlanID
int64
}
type
CreateOrderResponse
struct
{
...
...
@@ -88,6 +90,7 @@ type CreateOrderResponse struct {
ClientSecret
string
`json:"client_secret,omitempty"`
ExpiresAt
time
.
Time
`json:"expires_at"`
PaymentMode
string
`json:"payment_mode,omitempty"`
ResumeToken
string
`json:"resume_token,omitempty"`
}
type
OrderListParams
struct
{
...
...
@@ -165,10 +168,13 @@ type PaymentService struct {
configService
*
PaymentConfigService
userRepo
UserRepository
groupRepo
GroupRepository
resumeService
*
PaymentResumeService
}
func
NewPaymentService
(
entClient
*
dbent
.
Client
,
registry
*
payment
.
Registry
,
loadBalancer
payment
.
LoadBalancer
,
redeemService
*
RedeemService
,
subscriptionSvc
*
SubscriptionService
,
configService
*
PaymentConfigService
,
userRepo
UserRepository
,
groupRepo
GroupRepository
)
*
PaymentService
{
return
&
PaymentService
{
entClient
:
entClient
,
registry
:
registry
,
loadBalancer
:
loadBalancer
,
redeemService
:
redeemService
,
subscriptionSvc
:
subscriptionSvc
,
configService
:
configService
,
userRepo
:
userRepo
,
groupRepo
:
groupRepo
}
svc
:=
&
PaymentService
{
entClient
:
entClient
,
registry
:
registry
,
loadBalancer
:
newVisibleMethodLoadBalancer
(
loadBalancer
,
configService
),
redeemService
:
redeemService
,
subscriptionSvc
:
subscriptionSvc
,
configService
:
configService
,
userRepo
:
userRepo
,
groupRepo
:
groupRepo
}
svc
.
resumeService
=
NewPaymentResumeService
(
psResumeSigningKey
(
configService
))
return
svc
}
// --- Provider Registry ---
...
...
@@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string {
return
&
s
}
func
(
s
*
PaymentService
)
paymentResume
()
*
PaymentResumeService
{
if
s
.
resumeService
!=
nil
{
return
s
.
resumeService
}
return
NewPaymentResumeService
(
psResumeSigningKey
(
s
.
configService
))
}
func
psResumeSigningKey
(
configService
*
PaymentConfigService
)
[]
byte
{
if
configService
==
nil
{
return
nil
}
return
configService
.
encryptionKey
}
func
psSliceContains
(
sl
[]
string
,
s
string
)
bool
{
for
_
,
v
:=
range
sl
{
if
v
==
s
{
...
...
backend/internal/service/setting_service.go
View file @
e9de839d
...
...
@@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net/url"
"os"
"sort"
"strconv"
"strings"
...
...
@@ -114,6 +115,66 @@ type SettingService struct {
webSearchManagerBuilder
WebSearchManagerBuilder
}
type
ProviderDefaultGrantSettings
struct
{
Balance
float64
Concurrency
int
Subscriptions
[]
DefaultSubscriptionSetting
GrantOnSignup
bool
GrantOnFirstBind
bool
}
type
AuthSourceDefaultSettings
struct
{
Email
ProviderDefaultGrantSettings
LinuxDo
ProviderDefaultGrantSettings
OIDC
ProviderDefaultGrantSettings
WeChat
ProviderDefaultGrantSettings
ForceEmailOnThirdPartySignup
bool
}
type
authSourceDefaultKeySet
struct
{
balance
string
concurrency
string
subscriptions
string
grantOnSignup
string
grantOnFirstBind
string
}
var
(
emailAuthSourceDefaultKeys
=
authSourceDefaultKeySet
{
balance
:
SettingKeyAuthSourceDefaultEmailBalance
,
concurrency
:
SettingKeyAuthSourceDefaultEmailConcurrency
,
subscriptions
:
SettingKeyAuthSourceDefaultEmailSubscriptions
,
grantOnSignup
:
SettingKeyAuthSourceDefaultEmailGrantOnSignup
,
grantOnFirstBind
:
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind
,
}
linuxDoAuthSourceDefaultKeys
=
authSourceDefaultKeySet
{
balance
:
SettingKeyAuthSourceDefaultLinuxDoBalance
,
concurrency
:
SettingKeyAuthSourceDefaultLinuxDoConcurrency
,
subscriptions
:
SettingKeyAuthSourceDefaultLinuxDoSubscriptions
,
grantOnSignup
:
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup
,
grantOnFirstBind
:
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind
,
}
oidcAuthSourceDefaultKeys
=
authSourceDefaultKeySet
{
balance
:
SettingKeyAuthSourceDefaultOIDCBalance
,
concurrency
:
SettingKeyAuthSourceDefaultOIDCConcurrency
,
subscriptions
:
SettingKeyAuthSourceDefaultOIDCSubscriptions
,
grantOnSignup
:
SettingKeyAuthSourceDefaultOIDCGrantOnSignup
,
grantOnFirstBind
:
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind
,
}
weChatAuthSourceDefaultKeys
=
authSourceDefaultKeySet
{
balance
:
SettingKeyAuthSourceDefaultWeChatBalance
,
concurrency
:
SettingKeyAuthSourceDefaultWeChatConcurrency
,
subscriptions
:
SettingKeyAuthSourceDefaultWeChatSubscriptions
,
grantOnSignup
:
SettingKeyAuthSourceDefaultWeChatGrantOnSignup
,
grantOnFirstBind
:
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind
,
}
)
const
(
defaultAuthSourceBalance
=
0
defaultAuthSourceConcurrency
=
5
)
// NewSettingService 创建系统设置服务实例
func
NewSettingService
(
settingRepo
SettingRepository
,
cfg
*
config
.
Config
)
*
SettingService
{
return
&
SettingService
{
...
...
@@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if
oidcProviderName
==
""
{
oidcProviderName
=
"OIDC"
}
weChatEnabled
:=
isWeChatOAuthConfigured
()
// Password reset requires email verification to be enabled
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
...
...
@@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems
:
settings
[
SettingKeyCustomMenuItems
],
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
WeChatOAuthEnabled
:
weChatEnabled
,
BackendModeEnabled
:
settings
[
SettingKeyBackendModeEnabled
]
==
"true"
,
PaymentEnabled
:
settings
[
SettingPaymentEnabled
]
==
"true"
,
OIDCOAuthEnabled
:
oidcEnabled
,
...
...
@@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems
json
.
RawMessage
`json:"custom_menu_items"`
CustomEndpoints
json
.
RawMessage
`json:"custom_endpoints"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled
bool
`json:"wechat_oauth_enabled"`
BackendModeEnabled
bool
`json:"backend_mode_enabled"`
PaymentEnabled
bool
`json:"payment_enabled"`
OIDCOAuthEnabled
bool
`json:"oidc_oauth_enabled"`
...
...
@@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems
:
filterUserVisibleMenuItems
(
settings
.
CustomMenuItems
),
CustomEndpoints
:
safeRawJSONArray
(
settings
.
CustomEndpoints
),
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
WeChatOAuthEnabled
:
settings
.
WeChatOAuthEnabled
,
BackendModeEnabled
:
settings
.
BackendModeEnabled
,
PaymentEnabled
:
settings
.
PaymentEnabled
,
OIDCOAuthEnabled
:
settings
.
OIDCOAuthEnabled
,
...
...
@@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return
result
}
func
isWeChatOAuthConfigured
()
bool
{
openConfigured
:=
strings
.
TrimSpace
(
os
.
Getenv
(
"WECHAT_OAUTH_OPEN_APP_ID"
))
!=
""
&&
strings
.
TrimSpace
(
os
.
Getenv
(
"WECHAT_OAUTH_OPEN_APP_SECRET"
))
!=
""
mpConfigured
:=
strings
.
TrimSpace
(
os
.
Getenv
(
"WECHAT_OAUTH_MP_APP_ID"
))
!=
""
&&
strings
.
TrimSpace
(
os
.
Getenv
(
"WECHAT_OAUTH_MP_APP_SECRET"
))
!=
""
return
openConfigured
||
mpConfigured
}
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func
safeRawJSONArray
(
raw
string
)
json
.
RawMessage
{
raw
=
strings
.
TrimSpace
(
raw
)
...
...
@@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return
parseDefaultSubscriptions
(
value
)
}
func
(
s
*
SettingService
)
GetAuthSourceDefaultSettings
(
ctx
context
.
Context
)
(
*
AuthSourceDefaultSettings
,
error
)
{
keys
:=
[]
string
{
SettingKeyAuthSourceDefaultEmailBalance
,
SettingKeyAuthSourceDefaultEmailConcurrency
,
SettingKeyAuthSourceDefaultEmailSubscriptions
,
SettingKeyAuthSourceDefaultEmailGrantOnSignup
,
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind
,
SettingKeyAuthSourceDefaultLinuxDoBalance
,
SettingKeyAuthSourceDefaultLinuxDoConcurrency
,
SettingKeyAuthSourceDefaultLinuxDoSubscriptions
,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup
,
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind
,
SettingKeyAuthSourceDefaultOIDCBalance
,
SettingKeyAuthSourceDefaultOIDCConcurrency
,
SettingKeyAuthSourceDefaultOIDCSubscriptions
,
SettingKeyAuthSourceDefaultOIDCGrantOnSignup
,
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind
,
SettingKeyAuthSourceDefaultWeChatBalance
,
SettingKeyAuthSourceDefaultWeChatConcurrency
,
SettingKeyAuthSourceDefaultWeChatSubscriptions
,
SettingKeyAuthSourceDefaultWeChatGrantOnSignup
,
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind
,
SettingKeyForceEmailOnThirdPartySignup
,
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get auth source default settings: %w"
,
err
)
}
return
&
AuthSourceDefaultSettings
{
Email
:
parseProviderDefaultGrantSettings
(
settings
,
emailAuthSourceDefaultKeys
),
LinuxDo
:
parseProviderDefaultGrantSettings
(
settings
,
linuxDoAuthSourceDefaultKeys
),
OIDC
:
parseProviderDefaultGrantSettings
(
settings
,
oidcAuthSourceDefaultKeys
),
WeChat
:
parseProviderDefaultGrantSettings
(
settings
,
weChatAuthSourceDefaultKeys
),
ForceEmailOnThirdPartySignup
:
settings
[
SettingKeyForceEmailOnThirdPartySignup
]
==
"true"
,
},
nil
}
func
(
s
*
SettingService
)
UpdateAuthSourceDefaultSettings
(
ctx
context
.
Context
,
settings
*
AuthSourceDefaultSettings
)
error
{
if
settings
==
nil
{
return
nil
}
for
_
,
subscriptions
:=
range
[][]
DefaultSubscriptionSetting
{
settings
.
Email
.
Subscriptions
,
settings
.
LinuxDo
.
Subscriptions
,
settings
.
OIDC
.
Subscriptions
,
settings
.
WeChat
.
Subscriptions
,
}
{
if
err
:=
s
.
validateDefaultSubscriptionGroups
(
ctx
,
subscriptions
);
err
!=
nil
{
return
err
}
}
updates
:=
make
(
map
[
string
]
string
,
21
)
writeProviderDefaultGrantUpdates
(
updates
,
emailAuthSourceDefaultKeys
,
settings
.
Email
)
writeProviderDefaultGrantUpdates
(
updates
,
linuxDoAuthSourceDefaultKeys
,
settings
.
LinuxDo
)
writeProviderDefaultGrantUpdates
(
updates
,
oidcAuthSourceDefaultKeys
,
settings
.
OIDC
)
writeProviderDefaultGrantUpdates
(
updates
,
weChatAuthSourceDefaultKeys
,
settings
.
WeChat
)
updates
[
SettingKeyForceEmailOnThirdPartySignup
]
=
strconv
.
FormatBool
(
settings
.
ForceEmailOnThirdPartySignup
)
if
err
:=
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update auth source default settings: %w"
,
err
)
}
return
nil
}
// InitializeDefaultSettings 初始化默认设置
func
(
s
*
SettingService
)
InitializeDefaultSettings
(
ctx
context
.
Context
)
error
{
// 检查是否已有设置
...
...
@@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults
:=
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"false"
,
SettingKeyRegistrationEmailSuffixWhitelist
:
"[]"
,
SettingKeyPromoCodeEnabled
:
"true"
,
// 默认启用优惠码功能
SettingKeySiteName
:
"Sub2API"
,
SettingKeySiteLogo
:
""
,
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeyTableDefaultPageSize
:
"20"
,
SettingKeyTablePageSizeOptions
:
"[10,20,50,100]"
,
SettingKeyCustomMenuItems
:
"[]"
,
SettingKeyCustomEndpoints
:
"[]"
,
SettingKeyOIDCConnectEnabled
:
"false"
,
SettingKeyOIDCConnectProviderName
:
"OIDC"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeyDefaultSubscriptions
:
"[]"
,
SettingKeySMTPPort
:
"587"
,
SettingKeySMTPUseTLS
:
"false"
,
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"false"
,
SettingKeyRegistrationEmailSuffixWhitelist
:
"[]"
,
SettingKeyPromoCodeEnabled
:
"true"
,
// 默认启用优惠码功能
SettingKeySiteName
:
"Sub2API"
,
SettingKeySiteLogo
:
""
,
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeyTableDefaultPageSize
:
"20"
,
SettingKeyTablePageSizeOptions
:
"[10,20,50,100]"
,
SettingKeyCustomMenuItems
:
"[]"
,
SettingKeyCustomEndpoints
:
"[]"
,
SettingKeyOIDCConnectEnabled
:
"false"
,
SettingKeyOIDCConnectProviderName
:
"OIDC"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeyDefaultSubscriptions
:
"[]"
,
SettingKeyAuthSourceDefaultEmailBalance
:
"0"
,
SettingKeyAuthSourceDefaultEmailConcurrency
:
"5"
,
SettingKeyAuthSourceDefaultEmailSubscriptions
:
"[]"
,
SettingKeyAuthSourceDefaultEmailGrantOnSignup
:
"true"
,
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind
:
"false"
,
SettingKeyAuthSourceDefaultLinuxDoBalance
:
"0"
,
SettingKeyAuthSourceDefaultLinuxDoConcurrency
:
"5"
,
SettingKeyAuthSourceDefaultLinuxDoSubscriptions
:
"[]"
,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup
:
"true"
,
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind
:
"false"
,
SettingKeyAuthSourceDefaultOIDCBalance
:
"0"
,
SettingKeyAuthSourceDefaultOIDCConcurrency
:
"5"
,
SettingKeyAuthSourceDefaultOIDCSubscriptions
:
"[]"
,
SettingKeyAuthSourceDefaultOIDCGrantOnSignup
:
"true"
,
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind
:
"false"
,
SettingKeyAuthSourceDefaultWeChatBalance
:
"0"
,
SettingKeyAuthSourceDefaultWeChatConcurrency
:
"5"
,
SettingKeyAuthSourceDefaultWeChatSubscriptions
:
"[]"
,
SettingKeyAuthSourceDefaultWeChatGrantOnSignup
:
"true"
,
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind
:
"false"
,
SettingKeyForceEmailOnThirdPartySignup
:
"false"
,
SettingKeySMTPPort
:
"587"
,
SettingKeySMTPUseTLS
:
"false"
,
// Model fallback defaults
SettingKeyEnableModelFallback
:
"false"
,
SettingKeyFallbackModelAnthropic
:
"claude-3-5-sonnet-20241022"
,
...
...
@@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
else
{
result
.
OIDCConnectValidateIDToken
=
oidcBase
.
ValidateIDToken
}
result
.
OIDCConnectUsePKCE
=
true
result
.
OIDCConnectValidateIDToken
=
true
if
v
,
ok
:=
settings
[
SettingKeyOIDCConnectAllowedSigningAlgs
];
ok
&&
strings
.
TrimSpace
(
v
)
!=
""
{
result
.
OIDCConnectAllowedSigningAlgs
=
strings
.
TrimSpace
(
v
)
}
else
{
...
...
@@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return
normalized
}
func
parseProviderDefaultGrantSettings
(
settings
map
[
string
]
string
,
keys
authSourceDefaultKeySet
)
ProviderDefaultGrantSettings
{
result
:=
ProviderDefaultGrantSettings
{
Balance
:
defaultAuthSourceBalance
,
Concurrency
:
defaultAuthSourceConcurrency
,
Subscriptions
:
[]
DefaultSubscriptionSetting
{},
GrantOnSignup
:
true
,
GrantOnFirstBind
:
false
,
}
if
v
,
err
:=
strconv
.
ParseFloat
(
strings
.
TrimSpace
(
settings
[
keys
.
balance
]),
64
);
err
==
nil
{
result
.
Balance
=
v
}
if
v
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
settings
[
keys
.
concurrency
]));
err
==
nil
{
result
.
Concurrency
=
v
}
if
items
:=
parseDefaultSubscriptions
(
settings
[
keys
.
subscriptions
]);
items
!=
nil
{
result
.
Subscriptions
=
items
}
if
raw
,
ok
:=
settings
[
keys
.
grantOnSignup
];
ok
{
result
.
GrantOnSignup
=
raw
==
"true"
}
if
raw
,
ok
:=
settings
[
keys
.
grantOnFirstBind
];
ok
{
result
.
GrantOnFirstBind
=
raw
==
"true"
}
return
result
}
func
writeProviderDefaultGrantUpdates
(
updates
map
[
string
]
string
,
keys
authSourceDefaultKeySet
,
settings
ProviderDefaultGrantSettings
)
{
updates
[
keys
.
balance
]
=
strconv
.
FormatFloat
(
settings
.
Balance
,
'f'
,
8
,
64
)
updates
[
keys
.
concurrency
]
=
strconv
.
Itoa
(
settings
.
Concurrency
)
subscriptions
:=
settings
.
Subscriptions
if
subscriptions
==
nil
{
subscriptions
=
[]
DefaultSubscriptionSetting
{}
}
raw
,
err
:=
json
.
Marshal
(
subscriptions
)
if
err
!=
nil
{
raw
=
[]
byte
(
"[]"
)
}
updates
[
keys
.
subscriptions
]
=
string
(
raw
)
updates
[
keys
.
grantOnSignup
]
=
strconv
.
FormatBool
(
settings
.
GrantOnSignup
)
updates
[
keys
.
grantOnFirstBind
]
=
strconv
.
FormatBool
(
settings
.
GrantOnFirstBind
)
}
func
parseTablePreferences
(
defaultPageSizeRaw
,
optionsRaw
string
)
(
int
,
[]
int
)
{
defaultPageSize
:=
20
if
v
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
defaultPageSizeRaw
));
err
==
nil
{
...
...
@@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if
v
,
ok
:=
settings
[
SettingKeyLinuxDoConnectRedirectURL
];
ok
&&
strings
.
TrimSpace
(
v
)
!=
""
{
effective
.
RedirectURL
=
strings
.
TrimSpace
(
v
)
}
effective
.
UsePKCE
=
true
if
!
effective
.
Enabled
{
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
NotFound
(
"OAUTH_DISABLED"
,
"oauth login is disabled"
)
...
...
@@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth client secret not configured"
)
}
case
"none"
:
if
!
effective
.
UsePKCE
{
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth pkce must be enabled when token_auth_method=none"
)
}
default
:
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth token_auth_method invalid"
)
}
...
...
@@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
if
raw
,
ok
:=
settings
[
SettingKeyOIDCConnectValidateIDToken
];
ok
{
effective
.
ValidateIDToken
=
raw
==
"true"
}
effective
.
UsePKCE
=
true
effective
.
ValidateIDToken
=
true
if
v
,
ok
:=
settings
[
SettingKeyOIDCConnectAllowedSigningAlgs
];
ok
&&
strings
.
TrimSpace
(
v
)
!=
""
{
effective
.
AllowedSigningAlgs
=
strings
.
TrimSpace
(
v
)
}
...
...
@@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return
config
.
OIDCConnectConfig
{},
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth client secret not configured"
)
}
case
"none"
:
if
!
effective
.
UsePKCE
{
return
config
.
OIDCConnectConfig
{},
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth pkce must be enabled when token_auth_method=none"
)
}
default
:
return
config
.
OIDCConnectConfig
{},
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth token_auth_method invalid"
)
}
...
...
backend/internal/service/setting_service_auth_source_defaults_test.go
0 → 100644
View file @
e9de839d
//go:build unit
package
service
import
(
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type
authSourceDefaultsRepoStub
struct
{
values
map
[
string
]
string
updates
map
[
string
]
string
}
func
(
s
*
authSourceDefaultsRepoStub
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
Setting
,
error
)
{
panic
(
"unexpected Get call"
)
}
func
(
s
*
authSourceDefaultsRepoStub
)
GetValue
(
ctx
context
.
Context
,
key
string
)
(
string
,
error
)
{
panic
(
"unexpected GetValue call"
)
}
func
(
s
*
authSourceDefaultsRepoStub
)
Set
(
ctx
context
.
Context
,
key
,
value
string
)
error
{
panic
(
"unexpected Set call"
)
}
func
(
s
*
authSourceDefaultsRepoStub
)
GetMultiple
(
ctx
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
out
:=
make
(
map
[
string
]
string
,
len
(
keys
))
for
_
,
key
:=
range
keys
{
if
value
,
ok
:=
s
.
values
[
key
];
ok
{
out
[
key
]
=
value
}
}
return
out
,
nil
}
func
(
s
*
authSourceDefaultsRepoStub
)
SetMultiple
(
ctx
context
.
Context
,
settings
map
[
string
]
string
)
error
{
s
.
updates
=
make
(
map
[
string
]
string
,
len
(
settings
))
for
key
,
value
:=
range
settings
{
s
.
updates
[
key
]
=
value
if
s
.
values
==
nil
{
s
.
values
=
map
[
string
]
string
{}
}
s
.
values
[
key
]
=
value
}
return
nil
}
func
(
s
*
authSourceDefaultsRepoStub
)
GetAll
(
ctx
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected GetAll call"
)
}
func
(
s
*
authSourceDefaultsRepoStub
)
Delete
(
ctx
context
.
Context
,
key
string
)
error
{
panic
(
"unexpected Delete call"
)
}
func
TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults
(
t
*
testing
.
T
)
{
repo
:=
&
authSourceDefaultsRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyAuthSourceDefaultEmailBalance
:
"12.5"
,
SettingKeyAuthSourceDefaultEmailConcurrency
:
"7"
,
SettingKeyAuthSourceDefaultEmailSubscriptions
:
`[{"group_id":11,"validity_days":30}]`
,
SettingKeyAuthSourceDefaultEmailGrantOnSignup
:
"false"
,
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind
:
"true"
,
SettingKeyForceEmailOnThirdPartySignup
:
"true"
,
},
}
svc
:=
NewSettingService
(
repo
,
&
config
.
Config
{})
got
,
err
:=
svc
.
GetAuthSourceDefaultSettings
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
12.5
,
got
.
Email
.
Balance
)
require
.
Equal
(
t
,
7
,
got
.
Email
.
Concurrency
)
require
.
Equal
(
t
,
[]
DefaultSubscriptionSetting
{{
GroupID
:
11
,
ValidityDays
:
30
}},
got
.
Email
.
Subscriptions
)
require
.
False
(
t
,
got
.
Email
.
GrantOnSignup
)
require
.
False
(
t
,
got
.
Email
.
GrantOnFirstBind
)
require
.
Equal
(
t
,
0.0
,
got
.
LinuxDo
.
Balance
)
require
.
Equal
(
t
,
5
,
got
.
LinuxDo
.
Concurrency
)
require
.
Equal
(
t
,
[]
DefaultSubscriptionSetting
{},
got
.
LinuxDo
.
Subscriptions
)
require
.
True
(
t
,
got
.
LinuxDo
.
GrantOnSignup
)
require
.
True
(
t
,
got
.
LinuxDo
.
GrantOnFirstBind
)
require
.
Equal
(
t
,
5
,
got
.
OIDC
.
Concurrency
)
require
.
Equal
(
t
,
5
,
got
.
WeChat
.
Concurrency
)
require
.
True
(
t
,
got
.
ForceEmailOnThirdPartySignup
)
}
func
TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys
(
t
*
testing
.
T
)
{
repo
:=
&
authSourceDefaultsRepoStub
{}
svc
:=
NewSettingService
(
repo
,
&
config
.
Config
{})
err
:=
svc
.
UpdateAuthSourceDefaultSettings
(
context
.
Background
(),
&
AuthSourceDefaultSettings
{
Email
:
ProviderDefaultGrantSettings
{
Balance
:
1.25
,
Concurrency
:
3
,
Subscriptions
:
[]
DefaultSubscriptionSetting
{{
GroupID
:
21
,
ValidityDays
:
14
}},
GrantOnSignup
:
false
,
GrantOnFirstBind
:
true
,
},
LinuxDo
:
ProviderDefaultGrantSettings
{
Balance
:
2
,
Concurrency
:
4
,
Subscriptions
:
[]
DefaultSubscriptionSetting
{{
GroupID
:
22
,
ValidityDays
:
30
}},
GrantOnSignup
:
true
,
GrantOnFirstBind
:
false
,
},
OIDC
:
ProviderDefaultGrantSettings
{
Balance
:
3
,
Concurrency
:
5
,
Subscriptions
:
[]
DefaultSubscriptionSetting
{{
GroupID
:
23
,
ValidityDays
:
60
}},
GrantOnSignup
:
true
,
GrantOnFirstBind
:
true
,
},
WeChat
:
ProviderDefaultGrantSettings
{
Balance
:
4
,
Concurrency
:
6
,
Subscriptions
:
[]
DefaultSubscriptionSetting
{{
GroupID
:
24
,
ValidityDays
:
90
}},
GrantOnSignup
:
false
,
GrantOnFirstBind
:
false
,
},
ForceEmailOnThirdPartySignup
:
true
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"1.25000000"
,
repo
.
updates
[
SettingKeyAuthSourceDefaultEmailBalance
])
require
.
Equal
(
t
,
"3"
,
repo
.
updates
[
SettingKeyAuthSourceDefaultEmailConcurrency
])
require
.
Equal
(
t
,
"false"
,
repo
.
updates
[
SettingKeyAuthSourceDefaultEmailGrantOnSignup
])
require
.
Equal
(
t
,
"true"
,
repo
.
updates
[
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind
])
require
.
Equal
(
t
,
"true"
,
repo
.
updates
[
SettingKeyForceEmailOnThirdPartySignup
])
var
got
[]
DefaultSubscriptionSetting
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
repo
.
updates
[
SettingKeyAuthSourceDefaultWeChatSubscriptions
]),
&
got
))
require
.
Equal
(
t
,
[]
DefaultSubscriptionSetting
{{
GroupID
:
24
,
ValidityDays
:
90
}},
got
)
}
backend/internal/service/settings_view.go
View file @
e9de839d
...
...
@@ -152,6 +152,7 @@ type PublicSettings struct {
CustomEndpoints
string
// JSON array of custom endpoints
LinuxDoOAuthEnabled
bool
WeChatOAuthEnabled
bool
BackendModeEnabled
bool
PaymentEnabled
bool
OIDCOAuthEnabled
bool
...
...
backend/internal/service/user.go
View file @
e9de839d
...
...
@@ -7,19 +7,27 @@ import (
)
type
User
struct
{
ID
int64
Email
string
Username
string
Notes
string
PasswordHash
string
Role
string
Balance
float64
Concurrency
int
Status
string
AllowedGroups
[]
int64
TokenVersion
int64
// Incremented on password change to invalidate existing tokens
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
ID
int64
Email
string
Username
string
Notes
string
AvatarURL
string
AvatarSource
string
AvatarMIME
string
AvatarByteSize
int
AvatarSHA256
string
PasswordHash
string
Role
string
Balance
float64
Concurrency
int
Status
string
AllowedGroups
[]
int64
TokenVersion
int64
// Incremented on password change to invalidate existing tokens
SignupSource
string
LastLoginAt
*
time
.
Time
LastActiveAt
*
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
...
...
backend/internal/service/user_service.go
View file @
e9de839d
...
...
@@ -2,9 +2,13 @@ package service
import
(
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"log/slog"
"net/url"
"strings"
"time"
...
...
@@ -17,10 +21,14 @@ var (
ErrPasswordIncorrect
=
infraerrors
.
BadRequest
(
"PASSWORD_INCORRECT"
,
"current password is incorrect"
)
ErrInsufficientPerms
=
infraerrors
.
Forbidden
(
"INSUFFICIENT_PERMISSIONS"
,
"insufficient permissions"
)
ErrNotifyCodeUserRateLimit
=
infraerrors
.
TooManyRequests
(
"NOTIFY_CODE_USER_RATE_LIMIT"
,
"too many verification codes requested, please try again later"
)
ErrAvatarInvalid
=
infraerrors
.
BadRequest
(
"AVATAR_INVALID"
,
"avatar must be a valid image data URL or http(s) URL"
)
ErrAvatarTooLarge
=
infraerrors
.
BadRequest
(
"AVATAR_TOO_LARGE"
,
"avatar image must be 100KB or smaller"
)
ErrAvatarNotImage
=
infraerrors
.
BadRequest
(
"AVATAR_NOT_IMAGE"
,
"avatar content must be an image"
)
)
const
(
maxNotifyEmails
=
3
// Maximum number of notification emails per user
maxNotifyEmails
=
3
// Maximum number of notification emails per user
maxInlineAvatarBytes
=
100
*
1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit
=
5
...
...
@@ -47,6 +55,9 @@ type UserRepository interface {
GetFirstAdmin
(
ctx
context
.
Context
)
(
*
User
,
error
)
Update
(
ctx
context
.
Context
,
user
*
User
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
GetUserAvatar
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserAvatar
,
error
)
UpsertUserAvatar
(
ctx
context
.
Context
,
userID
int64
,
input
UpsertUserAvatarInput
)
(
*
UserAvatar
,
error
)
DeleteUserAvatar
(
ctx
context
.
Context
,
userID
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
...
...
@@ -71,11 +82,30 @@ type UserRepository interface {
type
UpdateProfileRequest
struct
{
Email
*
string
`json:"email"`
Username
*
string
`json:"username"`
AvatarURL
*
string
`json:"avatar_url"`
Concurrency
*
int
`json:"concurrency"`
BalanceNotifyEnabled
*
bool
`json:"balance_notify_enabled"`
BalanceNotifyThreshold
*
float64
`json:"balance_notify_threshold"`
}
type
UserAvatar
struct
{
StorageProvider
string
StorageKey
string
URL
string
ContentType
string
ByteSize
int
SHA256
string
}
type
UpsertUserAvatarInput
struct
{
StorageProvider
string
StorageKey
string
URL
string
ContentType
string
ByteSize
int
SHA256
string
}
// ChangePasswordRequest 修改密码请求
type
ChangePasswordRequest
struct
{
CurrentPassword
string
`json:"current_password"`
...
...
@@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
err
:=
s
.
hydrateUserAvatar
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user avatar: %w"
,
err
)
}
return
user
,
nil
}
...
...
@@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user
.
Username
=
*
req
.
Username
}
if
req
.
AvatarURL
!=
nil
{
avatarValue
:=
strings
.
TrimSpace
(
*
req
.
AvatarURL
)
switch
{
case
avatarValue
==
""
:
if
err
:=
s
.
userRepo
.
DeleteUserAvatar
(
ctx
,
userID
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"delete avatar: %w"
,
err
)
}
applyUserAvatar
(
user
,
nil
)
default
:
avatarInput
,
err
:=
normalizeUserAvatarInput
(
avatarValue
)
if
err
!=
nil
{
return
nil
,
err
}
avatar
,
err
:=
s
.
userRepo
.
UpsertUserAvatar
(
ctx
,
userID
,
avatarInput
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"upsert avatar: %w"
,
err
)
}
applyUserAvatar
(
user
,
avatar
)
}
}
if
req
.
Concurrency
!=
nil
{
user
.
Concurrency
=
*
req
.
Concurrency
}
...
...
@@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return
user
,
nil
}
func
applyUserAvatar
(
user
*
User
,
avatar
*
UserAvatar
)
{
if
user
==
nil
{
return
}
if
avatar
==
nil
{
user
.
AvatarURL
=
""
user
.
AvatarSource
=
""
user
.
AvatarMIME
=
""
user
.
AvatarByteSize
=
0
user
.
AvatarSHA256
=
""
return
}
user
.
AvatarURL
=
avatar
.
URL
user
.
AvatarSource
=
avatar
.
StorageProvider
user
.
AvatarMIME
=
avatar
.
ContentType
user
.
AvatarByteSize
=
avatar
.
ByteSize
user
.
AvatarSHA256
=
avatar
.
SHA256
}
func
normalizeUserAvatarInput
(
raw
string
)
(
UpsertUserAvatarInput
,
error
)
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
UpsertUserAvatarInput
{},
ErrAvatarInvalid
}
if
strings
.
HasPrefix
(
raw
,
"data:"
)
{
return
normalizeInlineUserAvatarInput
(
raw
)
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
||
parsed
==
nil
{
return
UpsertUserAvatarInput
{},
ErrAvatarInvalid
}
if
!
strings
.
EqualFold
(
parsed
.
Scheme
,
"http"
)
&&
!
strings
.
EqualFold
(
parsed
.
Scheme
,
"https"
)
{
return
UpsertUserAvatarInput
{},
ErrAvatarInvalid
}
if
strings
.
TrimSpace
(
parsed
.
Host
)
==
""
{
return
UpsertUserAvatarInput
{},
ErrAvatarInvalid
}
return
UpsertUserAvatarInput
{
StorageProvider
:
"remote_url"
,
URL
:
raw
,
},
nil
}
func
normalizeInlineUserAvatarInput
(
raw
string
)
(
UpsertUserAvatarInput
,
error
)
{
body
:=
strings
.
TrimPrefix
(
raw
,
"data:"
)
meta
,
encoded
,
ok
:=
strings
.
Cut
(
body
,
","
)
if
!
ok
{
return
UpsertUserAvatarInput
{},
ErrAvatarInvalid
}
meta
=
strings
.
TrimSpace
(
meta
)
encoded
=
strings
.
TrimSpace
(
encoded
)
if
!
strings
.
HasSuffix
(
strings
.
ToLower
(
meta
),
";base64"
)
{
return
UpsertUserAvatarInput
{},
ErrAvatarInvalid
}
contentType
:=
strings
.
TrimSpace
(
meta
[
:
len
(
meta
)
-
len
(
";base64"
)])
if
contentType
==
""
||
!
strings
.
HasPrefix
(
strings
.
ToLower
(
contentType
),
"image/"
)
{
return
UpsertUserAvatarInput
{},
ErrAvatarNotImage
}
decoded
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
encoded
)
if
err
!=
nil
{
return
UpsertUserAvatarInput
{},
ErrAvatarInvalid
}
if
len
(
decoded
)
>
maxInlineAvatarBytes
{
return
UpsertUserAvatarInput
{},
ErrAvatarTooLarge
}
sum
:=
sha256
.
Sum256
(
decoded
)
return
UpsertUserAvatarInput
{
StorageProvider
:
"inline"
,
URL
:
raw
,
ContentType
:
contentType
,
ByteSize
:
len
(
decoded
),
SHA256
:
hex
.
EncodeToString
(
sum
[
:
]),
},
nil
}
// ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func
(
s
*
UserService
)
ChangePassword
(
ctx
context
.
Context
,
userID
int64
,
req
ChangePasswordRequest
)
error
{
...
...
@@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
err
:=
s
.
hydrateUserAvatar
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user avatar: %w"
,
err
)
}
return
user
,
nil
}
func
(
s
*
UserService
)
hydrateUserAvatar
(
ctx
context
.
Context
,
user
*
User
)
error
{
if
s
==
nil
||
s
.
userRepo
==
nil
||
user
==
nil
||
user
.
ID
==
0
{
return
nil
}
avatar
,
err
:=
s
.
userRepo
.
GetUserAvatar
(
ctx
,
user
.
ID
)
if
err
!=
nil
{
return
err
}
applyUserAvatar
(
user
,
avatar
)
return
nil
}
// List 获取用户列表(管理员功能)
func
(
s
*
UserService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
users
,
pagination
,
err
:=
s
.
userRepo
.
List
(
ctx
,
params
)
...
...
backend/internal/service/user_service_test.go
View file @
e9de839d
...
...
@@ -4,6 +4,9 @@ package service
import
(
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"sync"
"sync/atomic"
...
...
@@ -19,14 +22,65 @@ import (
type
mockUserRepo
struct
{
updateBalanceErr
error
updateBalanceFn
func
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
getByIDUser
*
User
getByIDErr
error
updateFn
func
(
ctx
context
.
Context
,
user
*
User
)
error
updateCalls
int
upsertAvatarFn
func
(
ctx
context
.
Context
,
userID
int64
,
input
UpsertUserAvatarInput
)
(
*
UserAvatar
,
error
)
upsertAvatarArgs
[]
UpsertUserAvatarInput
deleteAvatarFn
func
(
ctx
context
.
Context
,
userID
int64
)
error
deleteAvatarIDs
[]
int64
getAvatarFn
func
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserAvatar
,
error
)
}
func
(
m
*
mockUserRepo
)
Create
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
GetByID
(
context
.
Context
,
int64
)
(
*
User
,
error
)
{
return
&
User
{},
nil
}
func
(
m
*
mockUserRepo
)
Create
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
GetByID
(
context
.
Context
,
int64
)
(
*
User
,
error
)
{
if
m
.
getByIDErr
!=
nil
{
return
nil
,
m
.
getByIDErr
}
if
m
.
getByIDUser
!=
nil
{
cloned
:=
*
m
.
getByIDUser
return
&
cloned
,
nil
}
return
&
User
{},
nil
}
func
(
m
*
mockUserRepo
)
GetByEmail
(
context
.
Context
,
string
)
(
*
User
,
error
)
{
return
&
User
{},
nil
}
func
(
m
*
mockUserRepo
)
GetFirstAdmin
(
context
.
Context
)
(
*
User
,
error
)
{
return
&
User
{},
nil
}
func
(
m
*
mockUserRepo
)
Update
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
Update
(
ctx
context
.
Context
,
user
*
User
)
error
{
m
.
updateCalls
++
if
m
.
updateFn
!=
nil
{
return
m
.
updateFn
(
ctx
,
user
)
}
return
nil
}
func
(
m
*
mockUserRepo
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
GetUserAvatar
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserAvatar
,
error
)
{
if
m
.
getAvatarFn
!=
nil
{
return
m
.
getAvatarFn
(
ctx
,
userID
)
}
return
nil
,
nil
}
func
(
m
*
mockUserRepo
)
UpsertUserAvatar
(
ctx
context
.
Context
,
userID
int64
,
input
UpsertUserAvatarInput
)
(
*
UserAvatar
,
error
)
{
m
.
upsertAvatarArgs
=
append
(
m
.
upsertAvatarArgs
,
input
)
if
m
.
upsertAvatarFn
!=
nil
{
return
m
.
upsertAvatarFn
(
ctx
,
userID
,
input
)
}
return
&
UserAvatar
{
StorageProvider
:
input
.
StorageProvider
,
StorageKey
:
input
.
StorageKey
,
URL
:
input
.
URL
,
ContentType
:
input
.
ContentType
,
ByteSize
:
input
.
ByteSize
,
SHA256
:
input
.
SHA256
,
},
nil
}
func
(
m
*
mockUserRepo
)
DeleteUserAvatar
(
ctx
context
.
Context
,
userID
int64
)
error
{
m
.
deleteAvatarIDs
=
append
(
m
.
deleteAvatarIDs
,
userID
)
if
m
.
deleteAvatarFn
!=
nil
{
return
m
.
deleteAvatarFn
(
ctx
,
userID
)
}
return
nil
}
func
(
m
*
mockUserRepo
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
...
...
@@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require
.
Equal
(
t
,
auth
,
svc
.
authCacheInvalidator
)
require
.
Equal
(
t
,
cache
,
svc
.
billingCache
)
}
func
TestUpdateProfile_StoresInlineAvatarWithinLimit
(
t
*
testing
.
T
)
{
raw
:=
[]
byte
(
"small-avatar"
)
dataURL
:=
"data:image/png;base64,"
+
base64
.
StdEncoding
.
EncodeToString
(
raw
)
expectedSum
:=
sha256
.
Sum256
(
raw
)
repo
:=
&
mockUserRepo
{
getByIDUser
:
&
User
{
ID
:
7
,
Email
:
"avatar@example.com"
,
Username
:
"avatar-user"
,
},
}
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
nil
)
updated
,
err
:=
svc
.
UpdateProfile
(
context
.
Background
(),
7
,
UpdateProfileRequest
{
AvatarURL
:
&
dataURL
,
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
repo
.
upsertAvatarArgs
,
1
)
require
.
Equal
(
t
,
"inline"
,
repo
.
upsertAvatarArgs
[
0
]
.
StorageProvider
)
require
.
Equal
(
t
,
"image/png"
,
repo
.
upsertAvatarArgs
[
0
]
.
ContentType
)
require
.
Equal
(
t
,
len
(
raw
),
repo
.
upsertAvatarArgs
[
0
]
.
ByteSize
)
require
.
Equal
(
t
,
hex
.
EncodeToString
(
expectedSum
[
:
]),
repo
.
upsertAvatarArgs
[
0
]
.
SHA256
)
require
.
Equal
(
t
,
dataURL
,
updated
.
AvatarURL
)
require
.
Equal
(
t
,
"inline"
,
updated
.
AvatarSource
)
require
.
Equal
(
t
,
"image/png"
,
updated
.
AvatarMIME
)
require
.
Equal
(
t
,
len
(
raw
),
updated
.
AvatarByteSize
)
require
.
Equal
(
t
,
hex
.
EncodeToString
(
expectedSum
[
:
]),
updated
.
AvatarSHA256
)
}
func
TestUpdateProfile_RejectsInlineAvatarOverLimit
(
t
*
testing
.
T
)
{
raw
:=
make
([]
byte
,
maxInlineAvatarBytes
+
1
)
dataURL
:=
"data:image/png;base64,"
+
base64
.
StdEncoding
.
EncodeToString
(
raw
)
repo
:=
&
mockUserRepo
{
getByIDUser
:
&
User
{
ID
:
8
,
Email
:
"large-avatar@example.com"
,
Username
:
"too-large"
,
},
}
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
nil
)
_
,
err
:=
svc
.
UpdateProfile
(
context
.
Background
(),
8
,
UpdateProfileRequest
{
AvatarURL
:
&
dataURL
,
})
require
.
ErrorIs
(
t
,
err
,
ErrAvatarTooLarge
)
require
.
Empty
(
t
,
repo
.
upsertAvatarArgs
)
require
.
Empty
(
t
,
repo
.
deleteAvatarIDs
)
require
.
Zero
(
t
,
repo
.
updateCalls
)
}
func
TestUpdateProfile_StoresRemoteAvatarURL
(
t
*
testing
.
T
)
{
remoteURL
:=
"https://cdn.example.com/avatar.png"
repo
:=
&
mockUserRepo
{
getByIDUser
:
&
User
{
ID
:
9
,
Email
:
"remote-avatar@example.com"
,
Username
:
"remote-avatar"
,
},
}
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
nil
)
updated
,
err
:=
svc
.
UpdateProfile
(
context
.
Background
(),
9
,
UpdateProfileRequest
{
AvatarURL
:
&
remoteURL
,
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
repo
.
upsertAvatarArgs
,
1
)
require
.
Equal
(
t
,
"remote_url"
,
repo
.
upsertAvatarArgs
[
0
]
.
StorageProvider
)
require
.
Equal
(
t
,
remoteURL
,
repo
.
upsertAvatarArgs
[
0
]
.
URL
)
require
.
Equal
(
t
,
remoteURL
,
updated
.
AvatarURL
)
require
.
Equal
(
t
,
"remote_url"
,
updated
.
AvatarSource
)
require
.
Zero
(
t
,
updated
.
AvatarByteSize
)
}
func
TestUpdateProfile_DeletesAvatarOnEmptyString
(
t
*
testing
.
T
)
{
empty
:=
""
repo
:=
&
mockUserRepo
{
getByIDUser
:
&
User
{
ID
:
10
,
Email
:
"delete-avatar@example.com"
,
Username
:
"delete-avatar"
,
AvatarURL
:
"https://cdn.example.com/old.png"
,
AvatarSource
:
"remote_url"
,
},
}
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
nil
)
updated
,
err
:=
svc
.
UpdateProfile
(
context
.
Background
(),
10
,
UpdateProfileRequest
{
AvatarURL
:
&
empty
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
int64
{
10
},
repo
.
deleteAvatarIDs
)
require
.
Empty
(
t
,
repo
.
upsertAvatarArgs
)
require
.
Empty
(
t
,
updated
.
AvatarURL
)
require
.
Empty
(
t
,
updated
.
AvatarSource
)
}
func
TestGetProfile_HydratesAvatarFromRepository
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{
getByIDUser
:
&
User
{
ID
:
12
,
Email
:
"profile-avatar@example.com"
,
Username
:
"profile-avatar"
,
},
getAvatarFn
:
func
(
context
.
Context
,
int64
)
(
*
UserAvatar
,
error
)
{
return
&
UserAvatar
{
StorageProvider
:
"remote_url"
,
URL
:
"https://cdn.example.com/profile.png"
,
},
nil
},
}
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
nil
)
user
,
err
:=
svc
.
GetProfile
(
context
.
Background
(),
12
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://cdn.example.com/profile.png"
,
user
.
AvatarURL
)
require
.
Equal
(
t
,
"remote_url"
,
user
.
AvatarSource
)
}
backend/migrations/108_auth_identity_foundation_core.sql
0 → 100644
View file @
e9de839d
ALTER
TABLE
users
ADD
COLUMN
IF
NOT
EXISTS
signup_source
VARCHAR
(
20
)
NOT
NULL
DEFAULT
'email'
,
ADD
COLUMN
IF
NOT
EXISTS
last_login_at
TIMESTAMPTZ
NULL
,
ADD
COLUMN
IF
NOT
EXISTS
last_active_at
TIMESTAMPTZ
NULL
;
UPDATE
users
SET
signup_source
=
'email'
WHERE
signup_source
IS
NULL
OR
signup_source
=
''
;
DO
$$
BEGIN
IF
NOT
EXISTS
(
SELECT
1
FROM
pg_constraint
WHERE
conname
=
'users_signup_source_check'
)
THEN
ALTER
TABLE
users
ADD
CONSTRAINT
users_signup_source_check
CHECK
(
signup_source
IN
(
'email'
,
'linuxdo'
,
'wechat'
,
'oidc'
));
END
IF
;
END
$$
;
CREATE
TABLE
IF
NOT
EXISTS
auth_identities
(
id
BIGSERIAL
PRIMARY
KEY
,
user_id
BIGINT
NOT
NULL
REFERENCES
users
(
id
)
ON
DELETE
CASCADE
,
provider_type
VARCHAR
(
20
)
NOT
NULL
,
provider_key
TEXT
NOT
NULL
,
provider_subject
TEXT
NOT
NULL
,
verified_at
TIMESTAMPTZ
NULL
,
issuer
TEXT
NULL
,
metadata
JSONB
NOT
NULL
DEFAULT
'{}'
::
jsonb
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
CONSTRAINT
auth_identities_provider_type_check
CHECK
(
provider_type
IN
(
'email'
,
'linuxdo'
,
'wechat'
,
'oidc'
))
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
auth_identities_provider_subject_key
ON
auth_identities
(
provider_type
,
provider_key
,
provider_subject
);
CREATE
INDEX
IF
NOT
EXISTS
auth_identities_user_id_idx
ON
auth_identities
(
user_id
);
CREATE
INDEX
IF
NOT
EXISTS
auth_identities_user_provider_idx
ON
auth_identities
(
user_id
,
provider_type
);
CREATE
TABLE
IF
NOT
EXISTS
auth_identity_channels
(
id
BIGSERIAL
PRIMARY
KEY
,
identity_id
BIGINT
NOT
NULL
REFERENCES
auth_identities
(
id
)
ON
DELETE
CASCADE
,
provider_type
VARCHAR
(
20
)
NOT
NULL
,
provider_key
TEXT
NOT
NULL
,
channel
VARCHAR
(
20
)
NOT
NULL
,
channel_app_id
TEXT
NOT
NULL
,
channel_subject
TEXT
NOT
NULL
,
metadata
JSONB
NOT
NULL
DEFAULT
'{}'
::
jsonb
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
CONSTRAINT
auth_identity_channels_provider_type_check
CHECK
(
provider_type
IN
(
'email'
,
'linuxdo'
,
'wechat'
,
'oidc'
))
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
auth_identity_channels_channel_key
ON
auth_identity_channels
(
provider_type
,
provider_key
,
channel
,
channel_app_id
,
channel_subject
);
CREATE
INDEX
IF
NOT
EXISTS
auth_identity_channels_identity_id_idx
ON
auth_identity_channels
(
identity_id
);
CREATE
TABLE
IF
NOT
EXISTS
pending_auth_sessions
(
id
BIGSERIAL
PRIMARY
KEY
,
session_token
VARCHAR
(
255
)
NOT
NULL
,
intent
VARCHAR
(
40
)
NOT
NULL
,
provider_type
VARCHAR
(
20
)
NOT
NULL
,
provider_key
TEXT
NOT
NULL
,
provider_subject
TEXT
NOT
NULL
,
target_user_id
BIGINT
NULL
REFERENCES
users
(
id
)
ON
DELETE
SET
NULL
,
redirect_to
TEXT
NOT
NULL
DEFAULT
''
,
resolved_email
TEXT
NOT
NULL
DEFAULT
''
,
registration_password_hash
TEXT
NOT
NULL
DEFAULT
''
,
upstream_identity_claims
JSONB
NOT
NULL
DEFAULT
'{}'
::
jsonb
,
local_flow_state
JSONB
NOT
NULL
DEFAULT
'{}'
::
jsonb
,
browser_session_key
TEXT
NOT
NULL
DEFAULT
''
,
completion_code_hash
TEXT
NOT
NULL
DEFAULT
''
,
completion_code_expires_at
TIMESTAMPTZ
NULL
,
email_verified_at
TIMESTAMPTZ
NULL
,
password_verified_at
TIMESTAMPTZ
NULL
,
totp_verified_at
TIMESTAMPTZ
NULL
,
expires_at
TIMESTAMPTZ
NOT
NULL
,
consumed_at
TIMESTAMPTZ
NULL
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
CONSTRAINT
pending_auth_sessions_intent_check
CHECK
(
intent
IN
(
'login'
,
'bind_current_user'
,
'adopt_existing_user_by_email'
)),
CONSTRAINT
pending_auth_sessions_provider_type_check
CHECK
(
provider_type
IN
(
'email'
,
'linuxdo'
,
'wechat'
,
'oidc'
))
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
pending_auth_sessions_session_token_key
ON
pending_auth_sessions
(
session_token
);
CREATE
INDEX
IF
NOT
EXISTS
pending_auth_sessions_target_user_id_idx
ON
pending_auth_sessions
(
target_user_id
);
CREATE
INDEX
IF
NOT
EXISTS
pending_auth_sessions_expires_at_idx
ON
pending_auth_sessions
(
expires_at
);
CREATE
INDEX
IF
NOT
EXISTS
pending_auth_sessions_provider_idx
ON
pending_auth_sessions
(
provider_type
,
provider_key
,
provider_subject
);
CREATE
INDEX
IF
NOT
EXISTS
pending_auth_sessions_completion_code_idx
ON
pending_auth_sessions
(
completion_code_hash
);
CREATE
TABLE
IF
NOT
EXISTS
identity_adoption_decisions
(
id
BIGSERIAL
PRIMARY
KEY
,
pending_auth_session_id
BIGINT
NOT
NULL
REFERENCES
pending_auth_sessions
(
id
)
ON
DELETE
CASCADE
,
identity_id
BIGINT
NULL
REFERENCES
auth_identities
(
id
)
ON
DELETE
SET
NULL
,
adopt_display_name
BOOLEAN
NOT
NULL
DEFAULT
FALSE
,
adopt_avatar
BOOLEAN
NOT
NULL
DEFAULT
FALSE
,
decided_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
identity_adoption_decisions_pending_auth_session_id_key
ON
identity_adoption_decisions
(
pending_auth_session_id
);
CREATE
INDEX
IF
NOT
EXISTS
identity_adoption_decisions_identity_id_idx
ON
identity_adoption_decisions
(
identity_id
);
CREATE
TABLE
IF
NOT
EXISTS
auth_identity_migration_reports
(
id
BIGSERIAL
PRIMARY
KEY
,
report_type
VARCHAR
(
40
)
NOT
NULL
,
report_key
TEXT
NOT
NULL
,
details
JSONB
NOT
NULL
DEFAULT
'{}'
::
jsonb
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
INDEX
IF
NOT
EXISTS
auth_identity_migration_reports_type_idx
ON
auth_identity_migration_reports
(
report_type
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
auth_identity_migration_reports_type_key
ON
auth_identity_migration_reports
(
report_type
,
report_key
);
backend/migrations/109_auth_identity_compat_backfill.sql
0 → 100644
View file @
e9de839d
INSERT
INTO
auth_identities
(
user_id
,
provider_type
,
provider_key
,
provider_subject
,
verified_at
,
metadata
)
SELECT
u
.
id
,
'email'
,
'email'
,
LOWER
(
BTRIM
(
u
.
email
)),
COALESCE
(
u
.
updated_at
,
u
.
created_at
,
NOW
()),
jsonb_build_object
(
'backfill_source'
,
'users.email'
,
'migration'
,
'109_auth_identity_compat_backfill'
)
FROM
users
AS
u
WHERE
u
.
deleted_at
IS
NULL
AND
BTRIM
(
COALESCE
(
u
.
email
,
''
))
<>
''
AND
RIGHT
(
LOWER
(
BTRIM
(
u
.
email
)),
LENGTH
(
'@linuxdo-connect.invalid'
))
<>
'@linuxdo-connect.invalid'
AND
RIGHT
(
LOWER
(
BTRIM
(
u
.
email
)),
LENGTH
(
'@oidc-connect.invalid'
))
<>
'@oidc-connect.invalid'
AND
RIGHT
(
LOWER
(
BTRIM
(
u
.
email
)),
LENGTH
(
'@wechat-connect.invalid'
))
<>
'@wechat-connect.invalid'
ON
CONFLICT
(
provider_type
,
provider_key
,
provider_subject
)
DO
NOTHING
;
INSERT
INTO
auth_identities
(
user_id
,
provider_type
,
provider_key
,
provider_subject
,
verified_at
,
metadata
)
SELECT
u
.
id
,
'linuxdo'
,
'linuxdo'
,
SUBSTRING
(
BTRIM
(
u
.
email
)
FROM
'(?i)^linuxdo-(.+)@linuxdo-connect
\.
invalid$'
),
COALESCE
(
u
.
updated_at
,
u
.
created_at
,
NOW
()),
jsonb_build_object
(
'backfill_source'
,
'synthetic_email'
,
'legacy_email'
,
BTRIM
(
u
.
email
),
'migration'
,
'109_auth_identity_compat_backfill'
)
FROM
users
AS
u
WHERE
u
.
deleted_at
IS
NULL
AND
LOWER
(
BTRIM
(
u
.
email
))
~
'^linuxdo-.+@linuxdo-connect
\.
invalid$'
ON
CONFLICT
(
provider_type
,
provider_key
,
provider_subject
)
DO
NOTHING
;
INSERT
INTO
auth_identities
(
user_id
,
provider_type
,
provider_key
,
provider_subject
,
verified_at
,
metadata
)
SELECT
u
.
id
,
'wechat'
,
'wechat'
,
SUBSTRING
(
BTRIM
(
u
.
email
)
FROM
'(?i)^wechat-(.+)@wechat-connect
\.
invalid$'
),
COALESCE
(
u
.
updated_at
,
u
.
created_at
,
NOW
()),
jsonb_build_object
(
'backfill_source'
,
'synthetic_email'
,
'legacy_email'
,
BTRIM
(
u
.
email
),
'migration'
,
'109_auth_identity_compat_backfill'
)
FROM
users
AS
u
WHERE
u
.
deleted_at
IS
NULL
AND
LOWER
(
BTRIM
(
u
.
email
))
~
'^wechat-.+@wechat-connect
\.
invalid$'
ON
CONFLICT
(
provider_type
,
provider_key
,
provider_subject
)
DO
NOTHING
;
UPDATE
users
SET
signup_source
=
'linuxdo'
WHERE
deleted_at
IS
NULL
AND
LOWER
(
BTRIM
(
COALESCE
(
email
,
''
)))
~
'^linuxdo-.+@linuxdo-connect
\.
invalid$'
;
UPDATE
users
SET
signup_source
=
'wechat'
WHERE
deleted_at
IS
NULL
AND
LOWER
(
BTRIM
(
COALESCE
(
email
,
''
)))
~
'^wechat-.+@wechat-connect
\.
invalid$'
;
UPDATE
users
SET
signup_source
=
'oidc'
WHERE
deleted_at
IS
NULL
AND
LOWER
(
BTRIM
(
COALESCE
(
email
,
''
)))
~
'^oidc-.+@oidc-connect
\.
invalid$'
;
INSERT
INTO
auth_identity_migration_reports
(
report_type
,
report_key
,
details
)
SELECT
'oidc_synthetic_email_requires_manual_recovery'
,
CAST
(
u
.
id
AS
TEXT
),
jsonb_build_object
(
'user_id'
,
u
.
id
,
'email'
,
LOWER
(
BTRIM
(
u
.
email
)),
'reason'
,
'cannot recover issuer_plus_sub deterministically from synthetic email alone'
,
'migration'
,
'109_auth_identity_compat_backfill'
)
FROM
users
AS
u
WHERE
u
.
deleted_at
IS
NULL
AND
LOWER
(
BTRIM
(
u
.
email
))
~
'^oidc-.+@oidc-connect
\.
invalid$'
ON
CONFLICT
(
report_type
,
report_key
)
DO
NOTHING
;
INSERT
INTO
auth_identity_migration_reports
(
report_type
,
report_key
,
details
)
SELECT
'wechat_openid_only_requires_remediation'
,
CAST
(
u
.
id
AS
TEXT
),
jsonb_build_object
(
'user_id'
,
u
.
id
,
'email'
,
LOWER
(
BTRIM
(
u
.
email
)),
'reason'
,
'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists'
,
'migration'
,
'109_auth_identity_compat_backfill'
)
FROM
users
AS
u
WHERE
u
.
deleted_at
IS
NULL
AND
LOWER
(
BTRIM
(
u
.
email
))
~
'^wechat-.+@wechat-connect
\.
invalid$'
AND
NOT
EXISTS
(
SELECT
1
FROM
auth_identities
ai
WHERE
ai
.
user_id
=
u
.
id
AND
ai
.
provider_type
=
'wechat'
AND
ai
.
provider_key
=
'wechat'
)
ON
CONFLICT
(
report_type
,
report_key
)
DO
NOTHING
;
backend/migrations/110_pending_auth_and_provider_default_grants.sql
0 → 100644
View file @
e9de839d
CREATE
TABLE
IF
NOT
EXISTS
user_provider_default_grants
(
id
BIGSERIAL
PRIMARY
KEY
,
user_id
BIGINT
NOT
NULL
REFERENCES
users
(
id
)
ON
DELETE
CASCADE
,
provider_type
VARCHAR
(
20
)
NOT
NULL
,
grant_reason
VARCHAR
(
20
)
NOT
NULL
DEFAULT
'first_bind'
,
granted_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
CONSTRAINT
user_provider_default_grants_provider_type_check
CHECK
(
provider_type
IN
(
'email'
,
'linuxdo'
,
'wechat'
,
'oidc'
)),
CONSTRAINT
user_provider_default_grants_reason_check
CHECK
(
grant_reason
IN
(
'signup'
,
'first_bind'
))
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
user_provider_default_grants_user_provider_reason_key
ON
user_provider_default_grants
(
user_id
,
provider_type
,
grant_reason
);
CREATE
INDEX
IF
NOT
EXISTS
user_provider_default_grants_user_id_idx
ON
user_provider_default_grants
(
user_id
);
CREATE
TABLE
IF
NOT
EXISTS
user_avatars
(
id
BIGSERIAL
PRIMARY
KEY
,
user_id
BIGINT
NOT
NULL
REFERENCES
users
(
id
)
ON
DELETE
CASCADE
,
storage_provider
VARCHAR
(
20
)
NOT
NULL
DEFAULT
'database'
,
storage_key
TEXT
NOT
NULL
DEFAULT
''
,
url
TEXT
NOT
NULL
DEFAULT
''
,
content_type
VARCHAR
(
100
)
NOT
NULL
DEFAULT
''
,
byte_size
INT
NOT
NULL
DEFAULT
0
,
sha256
VARCHAR
(
64
)
NOT
NULL
DEFAULT
''
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
user_avatars_user_id_key
ON
user_avatars
(
user_id
);
INSERT
INTO
settings
(
key
,
value
)
VALUES
(
'auth_source_default_email_balance'
,
'0'
),
(
'auth_source_default_email_concurrency'
,
'5'
),
(
'auth_source_default_email_subscriptions'
,
'[]'
),
(
'auth_source_default_email_grant_on_signup'
,
'true'
),
(
'auth_source_default_email_grant_on_first_bind'
,
'false'
),
(
'auth_source_default_linuxdo_balance'
,
'0'
),
(
'auth_source_default_linuxdo_concurrency'
,
'5'
),
(
'auth_source_default_linuxdo_subscriptions'
,
'[]'
),
(
'auth_source_default_linuxdo_grant_on_signup'
,
'true'
),
(
'auth_source_default_linuxdo_grant_on_first_bind'
,
'false'
),
(
'auth_source_default_oidc_balance'
,
'0'
),
(
'auth_source_default_oidc_concurrency'
,
'5'
),
(
'auth_source_default_oidc_subscriptions'
,
'[]'
),
(
'auth_source_default_oidc_grant_on_signup'
,
'true'
),
(
'auth_source_default_oidc_grant_on_first_bind'
,
'false'
),
(
'auth_source_default_wechat_balance'
,
'0'
),
(
'auth_source_default_wechat_concurrency'
,
'5'
),
(
'auth_source_default_wechat_subscriptions'
,
'[]'
),
(
'auth_source_default_wechat_grant_on_signup'
,
'true'
),
(
'auth_source_default_wechat_grant_on_first_bind'
,
'false'
),
(
'force_email_on_third_party_signup'
,
'false'
)
ON
CONFLICT
(
key
)
DO
NOTHING
;
backend/migrations/111_payment_routing_and_scheduler_flags.sql
0 → 100644
View file @
e9de839d
INSERT
INTO
settings
(
key
,
value
)
VALUES
(
'payment_visible_method_alipay_source'
,
''
),
(
'payment_visible_method_wxpay_source'
,
''
),
(
'payment_visible_method_alipay_enabled'
,
'false'
),
(
'payment_visible_method_wxpay_enabled'
,
'false'
),
(
'openai_advanced_scheduler_enabled'
,
'false'
)
ON
CONFLICT
(
key
)
DO
NOTHING
;
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment