Unverified Commit 8eb3f9e7 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1785 from IanShaw027/rebuild/auth-identity-foundation

feat(auth,payment): 重构认证身份和支付系统及其他部分优化
parents 78f691d2 7fbd5177
//go:build unit
package service_test
import (
"context"
"database/sql"
"errors"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type emailBindDefaultSubAssignerStub struct {
calls []*service.AssignSubscriptionInput
}
func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
type flakyEmailBindDefaultSubAssignerStub struct {
err error
calls []*service.AssignSubscriptionInput
}
func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
return nil, false, s.err
}
func newAuthServiceForEmailBind(
t *testing.T,
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
provider_type TEXT NOT NULL,
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, provider_type, grant_reason)
)`)
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := repository.NewUserRepository(client, db)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-bind-email-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingRepo := &emailBindSettingRepoStub{values: settings}
settingSvc := service.NewSettingService(settingRepo, cfg)
var emailSvc *service.EmailService
if emailCache != nil {
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
assigner := &emailBindDefaultSubAssignerStub{}
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, cache, assigner)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
SetUsername("legacy-user").
SetPasswordHash("old-hash").
SetBalance(2.5).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
require.NoError(t, err)
require.NotNil(t, updatedUser)
require.Equal(t, "newemail@example.com", updatedUser.Email)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "newemail@example.com", storedUser.Email)
require.Equal(t, 11.0, storedUser.Balance)
require.Equal(t, 5, storedUser.Concurrency)
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("newemail@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
require.Len(t, assigner.calls, 1)
require.Equal(t, user.ID, assigner.calls[0].UserID)
require.Equal(t, int64(11), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
ctx := context.Background()
sourceUser, err := client.User.Create().
SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
SetUsername("source-user").
SetPasswordHash("old-hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.User.Create().
SetEmail("taken@example.com").
SetUsername("taken-user").
SetPasswordHash("hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
require.ErrorIs(t, err, service.ErrEmailExists)
require.Nil(t, updatedUser)
storedUser, err := client.User.Get(ctx, sourceUser.ID)
require.NoError(t, err)
require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, cache, assigner)
ctx := context.Background()
originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
user, err := client.User.Create().
SetEmail(originalEmail).
SetUsername("legacy-rollback").
SetPasswordHash("old-hash").
SetBalance(2.5).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
require.ErrorContains(t, err, "apply email first bind defaults")
require.ErrorContains(t, err, "temporary assign failure")
require.Nil(t, updatedUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, originalEmail, storedUser.Email)
require.Equal(t, "old-hash", storedUser.PasswordHash)
require.Equal(t, 2.5, storedUser.Balance)
require.Equal(t, 1, storedUser.Concurrency)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("rollback@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, identityCount)
require.Len(t, assigner.calls, 1)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("source-user@example.com").
SetUsername("source-user").
SetPasswordHash("old-hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
require.ErrorIs(t, err, service.ErrEmailReserved)
require.Nil(t, updatedUser)
}
func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) {
assigner := &emailBindDefaultSubAssignerStub{}
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, cache, assigner)
ctx := context.Background()
hashedPassword, err := svc.HashPassword("current-password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("current@example.com").
SetUsername("bound-user").
SetPasswordHash(hashedPassword).
SetBalance(7.5).
SetConcurrency(3).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
require.NoError(t, client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject("current@example.com").
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{"source": "test"}).
Exec(ctx))
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password")
require.NoError(t, err)
require.NotNil(t, updatedUser)
require.Equal(t, "new@example.com", updatedUser.Email)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "new@example.com", storedUser.Email)
require.Equal(t, 7.5, storedUser.Balance)
require.Equal(t, 3, storedUser.Concurrency)
require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
newIdentityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("new@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, newIdentityCount)
oldIdentityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("current@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, oldIdentityCount)
require.Empty(t, assigner.calls)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
ctx := context.Background()
hashedPassword, err := svc.HashPassword("current-password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("current@example.com").
SetUsername("bound-user").
SetPasswordHash(hashedPassword).
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
require.NoError(t, client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject("current@example.com").
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{"source": "test"}).
Exec(ctx))
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password")
require.ErrorIs(t, err, service.ErrPasswordIncorrect)
require.Nil(t, updatedUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "current@example.com", storedUser.Email)
require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
oldIdentityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("current@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, oldIdentityCount)
newIdentityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("new@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, newIdentityCount)
}
type emailBindSettingRepoStub struct {
values map[string]string
}
func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if v, ok := s.values[key]; ok {
out[key] = v
}
}
return out, nil
}
func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
type emailBindCacheStub struct {
data *service.VerificationCodeData
err error
}
func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
}
return s.data, nil
}
func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
return nil, nil
}
func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
return false
}
func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}
//go:build unit
package service_test
import (
"context"
"database/sql"
"errors"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type authIdentityDefaultSubAssignerStub struct {
calls []*service.AssignSubscriptionInput
}
func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
}
type flakyAuthIdentityDefaultSubAssignerStub struct {
failuresRemaining int
calls []*service.AssignSubscriptionInput
}
func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
if s.failuresRemaining > 0 {
s.failuresRemaining--
return nil, false, errors.New("temporary assign failure")
}
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
}
type authIdentitySettingRepoStub struct {
values map[string]string
}
func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if v, ok := s.values[key]; ok {
out[key] = v
}
}
return out, nil
}
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
func newAuthServiceWithEnt(
t *testing.T,
settings map[string]string,
defaultSubAssigner service.DefaultSubscriptionAssigner,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
provider_type TEXT NOT NULL,
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, provider_type, grant_reason)
)`)
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := repository.NewUserRepository(client, db)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-auth-identity-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
values: settings,
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
}, nil)
ctx := context.Background()
token, user, err := svc.Register(ctx, "user@example.com", "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, user)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "email", storedUser.SignupSource)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("user@example.com"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
require.NotNil(t, identity.VerifiedAt)
}
func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) {
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
}, nil)
ctx := context.Background()
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("login@example.com").
SetPasswordHash(passwordHash).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
SetBalance(1).
SetConcurrency(1).
Save(ctx)
require.NoError(t, err)
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
_, err = client.User.UpdateOneID(user.ID).
SetLastLoginAt(old).
SetLastActiveAt(old).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
require.True(t, storedUser.LastLoginAt.Equal(old))
require.True(t, storedUser.LastActiveAt.Equal(old))
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("login@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
svc.RecordSuccessfulLogin(ctx, user.ID)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("login@example.com"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
}
func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
}, nil)
ctx := context.Background()
user := &service.User{
Email: "record@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 1,
Concurrency: 1,
}
require.NoError(t, user.SetPassword("password"))
require.NoError(t, repo.Create(ctx, user))
svc.RecordSuccessfulLogin(ctx, user.ID)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("record@example.com"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
}
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, assigner)
ctx := context.Background()
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("legacy@example.com").
SetUsername("legacy-user").
SetPasswordHash(passwordHash).
SetBalance(1.5).
SetConcurrency(2).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 1.5, storedUser.Balance)
require.Equal(t, 2, storedUser.Concurrency)
require.Empty(t, assigner.calls)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("legacy@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
token, gotUser, err = svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 1.5, storedUser.Balance)
require.Equal(t, 2, storedUser.Concurrency)
require.Empty(t, assigner.calls)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceLogin_DoesNotApplyMergedEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`,
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "5",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, assigner)
ctx := context.Background()
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("merged-first-bind@example.com").
SetUsername("merged-user").
SetPasswordHash(passwordHash).
SetBalance(1.5).
SetConcurrency(2).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 1.5, storedUser.Balance)
require.Equal(t, 2, storedUser.Concurrency)
require.Empty(t, assigner.calls)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, assigner)
ctx := context.Background()
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("bound@example.com").
SetUsername("bound-user").
SetPasswordHash(passwordHash).
SetBalance(2).
SetConcurrency(3).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject("bound@example.com").
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{"source": "preexisting"}).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 2.0, storedUser.Balance)
require.Equal(t, 3, storedUser.Concurrency)
require.Empty(t, assigner.calls)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceLogin_DoesNotRetryEmailFirstBindDefaultsForBackfilledEmailIdentity(t *testing.T) {
assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, assigner)
ctx := context.Background()
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("retry-first-bind@example.com").
SetUsername("retry-user").
SetPasswordHash(passwordHash).
SetBalance(1.5).
SetConcurrency(2).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 1.5, storedUser.Balance)
require.Equal(t, 2, storedUser.Concurrency)
require.Empty(t, assigner.calls)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
token, gotUser, err = svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 1.5, storedUser.Balance)
require.Equal(t, 2, storedUser.Concurrency)
require.Empty(t, assigner.calls)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func countProviderGrantRecords(
t *testing.T,
client *dbent.Client,
userID int64,
providerType string,
grantReason string,
) int {
t.Helper()
var count int
rows, err := client.QueryContext(
context.Background(),
`SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
userID,
providerType,
grantReason,
)
require.NoError(t, err)
defer rows.Close()
require.True(t, rows.Next())
require.NoError(t, rows.Scan(&count))
require.NoError(t, rows.Err())
return count
}
//go:build unit
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func newAuthServiceForPendingOAuthTest() *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-pending-oauth",
ExpireHour: 1,
},
}
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
require.NotEmpty(t, token)
email, username, err := svc.VerifyPendingOAuthToken(token)
require.NoError(t, err)
require.Equal(t, "user@example.com", email)
require.Equal(t, "alice", username)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
accessToken, err := svc.GenerateToken(&User{
ID: 1,
Email: "user@example.com",
Role: RoleUser,
})
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "some_other_purpose",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "", // 旧 token 无此字段,反序列化后为零值
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
past := time.Now().Add(-1 * time.Hour)
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(past),
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
other := NewAuthService(nil, nil, nil, nil, &config.Config{
JWT: config.JWTConfig{Secret: "other-secret"},
}, nil, nil, nil, nil, nil, nil)
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
svc := newAuthServiceForPendingOAuthTest()
_, _, err = svc.VerifyPendingOAuthToken(token)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
giant := make([]byte, maxTokenLength+1)
for i := range giant {
giant[i] = 'a'
}
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
require.ErrorIs(t, err, ErrInvalidToken)
}
......@@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
if s.err != nil {
return nil, s.err
}
result := make(map[string]string, len(keys))
for _, key := range keys {
if v, ok := s.values[key]; ok {
result[key] = v
}
}
return result, nil
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
......@@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct {
err error
}
type refreshTokenCacheStub struct{}
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
......@@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error {
return nil
}
func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) {
return nil, ErrRefreshTokenNotFound
}
func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
return nil
}
func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
......@@ -322,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEnabled: "true",
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
......@@ -469,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
......@@ -484,3 +537,132 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}
func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) {
repo := &userRepoStub{nextID: 52}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`,
SettingKeyAuthSourceDefaultEmailBalance: "12.5",
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 12.5, user.Balance)
require.Equal(t, 7, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(11), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
}
func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) {
repo := &userRepoStub{nextID: 53}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
SettingKeyAuthSourceDefaultEmailBalance: "99",
SettingKeyAuthSourceDefaultEmailConcurrency: "88",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-global@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 3.5, user.Balance)
require.Equal(t, 2, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(31), assigner.calls[0].GroupID)
require.Equal(t, 5, assigner.calls[0].ValidityDays)
}
func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) {
repo := &userRepoStub{nextID: 54}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
SettingKeyAuthSourceDefaultEmailBalance: "9.5",
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 9.5, user.Balance)
require.Equal(t, 2, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(31), assigner.calls[0].GroupID)
require.Equal(t, 5, assigner.calls[0].ValidityDays)
}
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
repo := &userRepoStub{nextID: 61}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`,
SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Equal(t, int64(61), user.ID)
require.Equal(t, 21.75, user.Balance)
require.Equal(t, 9, user.Concurrency)
require.Len(t, repo.created, 1)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(22), assigner.calls[0].GroupID)
require.Equal(t, 14, assigner.calls[0].ValidityDays)
}
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) {
existing := &User{
ID: 88,
Email: "linuxdo-123@linuxdo-connect.invalid",
Username: "existing-linuxdo",
Role: RoleUser,
Status: StatusActive,
Balance: 4,
Concurrency: 1,
TokenVersion: 2,
}
repo := &userRepoStub{user: existing}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)
require.Equal(t, 4.0, user.Balance)
require.Equal(t, 1, user.Concurrency)
require.Empty(t, repo.created)
require.Empty(t, assigner.calls)
}
......@@ -86,6 +86,14 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User,
return &User{ID: id, Balance: s.balance}, nil
}
func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
return nil, nil
}
func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
return nil
}
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
......
......@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
// Setting keys
const (
// 注册设置
......@@ -108,6 +111,24 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// WeChat Connect OAuth 登录设置
SettingKeyWeChatConnectEnabled = "wechat_connect_enabled"
SettingKeyWeChatConnectAppID = "wechat_connect_app_id"
SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret"
SettingKeyWeChatConnectOpenAppID = "wechat_connect_open_app_id"
SettingKeyWeChatConnectOpenAppSecret = "wechat_connect_open_app_secret"
SettingKeyWeChatConnectMPAppID = "wechat_connect_mp_app_id"
SettingKeyWeChatConnectMPAppSecret = "wechat_connect_mp_app_secret"
SettingKeyWeChatConnectMobileAppID = "wechat_connect_mobile_app_id"
SettingKeyWeChatConnectMobileAppSecret = "wechat_connect_mobile_app_secret"
SettingKeyWeChatConnectOpenEnabled = "wechat_connect_open_enabled"
SettingKeyWeChatConnectMPEnabled = "wechat_connect_mp_enabled"
SettingKeyWeChatConnectMobileEnabled = "wechat_connect_mobile_enabled"
SettingKeyWeChatConnectMode = "wechat_connect_mode"
SettingKeyWeChatConnectScopes = "wechat_connect_scopes"
SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url"
SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url"
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
......@@ -153,6 +174,29 @@ const (
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
......
......@@ -13,14 +13,30 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
const (
openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
)
type cachedOpenAIAdvancedSchedulerSetting struct {
enabled bool
expiresAt int64
}
var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
var openAIAdvancedSchedulerSettingSF singleflight.Group
type OpenAIAccountScheduleRequest struct {
GroupID *int64
SessionHash string
......@@ -751,14 +767,13 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
// HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || s.service == nil || account == nil {
if s == nil || s.service == nil {
return false
}
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
......@@ -805,10 +820,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
return nil
}
return s.rateLimitService.settingService.settingRepo
}
func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled
}
}
result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled, nil
}
}
enabled := false
if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
defer cancel()
value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
if err == nil {
enabled = strings.EqualFold(strings.TrimSpace(value), "true")
}
}
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: enabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
return enabled, nil
})
enabled, _ := result.(bool)
return enabled
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
return nil
}
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
......@@ -820,6 +881,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler
}
func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
openAIAdvancedSchedulerSettingCache = atomic.Value{}
openAIAdvancedSchedulerSettingSF = singleflight.Group{}
}
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
......@@ -830,11 +896,37 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
return selection, decision, err
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
return selection, decision, err
}
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
if err != nil {
return nil, decision, err
}
if selection == nil || selection.Account == nil {
return selection, decision, nil
}
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
return selection, decision, nil
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if effectiveExcludedIDs == nil {
effectiveExcludedIDs = make(map[int64]struct{})
}
if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
return nil, decision, ErrNoAvailableAccounts
}
effectiveExcludedIDs[selection.Account.ID] = struct{}{}
}
}
var stickyAccountID int64
......@@ -855,8 +947,29 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
})
}
func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
if len(excludedIDs) == 0 {
return nil
}
cloned := make(map[int64]struct{}, len(excludedIDs))
for id := range excludedIDs {
cloned[id] = struct{}{}
}
return cloned
}
func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || account == nil {
return false
}
return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
......@@ -864,7 +977,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
......@@ -872,7 +985,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"math"
"sync"
......@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account
}
type schedulerTestOpenAIAccountRepo struct {
AccountRepository
accounts []Account
}
func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, errors.New("account not found")
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
type schedulerTestConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
loadMap map[int64]*AccountLoadInfo
acquireResults map[int64]bool
waitCounts map[int64]int
skipDefaultLoad bool
}
func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok {
return result, nil
}
}
return true, nil
}
func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if c.loadBatchErr != nil {
return nil, c.loadBatchErr
}
out := make(map[int64]*AccountLoadInfo, len(accounts))
if c.skipDefaultLoad && c.loadMap != nil {
for _, acc := range accounts {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
}
}
return out, nil
}
for _, acc := range accounts {
if c.loadMap != nil {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
continue
}
}
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
}
func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
return count, nil
}
}
return 0, nil
}
type schedulerTestGatewayCache struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := c.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if c.sessionBindings == nil {
c.sessionBindings = make(map[string]int64)
}
c.sessionBindings[sessionHash] = accountID
return nil
}
func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}
func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if c.sessionBindings == nil {
return nil
}
if c.deletedSessions == nil {
c.deletedSessions = make(map[string]int)
}
c.deletedSessions[sessionHash]++
delete(c.sessionBindings, sessionHash)
return nil
}
func newSchedulerTestOpenAIWSV2Config() *config.Config {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
return cfg
}
type openAIAdvancedSchedulerSettingRepoStub struct {
values map[string]string
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
value, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &Setting{Key: key, Value: value}, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if s == nil || s.values == nil {
return "", ErrSettingNotFound
}
value, ok := s.values[key]
if !ok {
return "", ErrSettingNotFound
}
return value, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected call to Set")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
panic("unexpected call to GetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected call to SetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected call to GetAll")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
panic("unexpected call to Delete")
}
func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
repo := &openAIAdvancedSchedulerSettingRepoStub{
values: map[string]string{},
}
if enabled != "" {
repo.values[openAIAdvancedSchedulerSettingKey] = enabled
}
return &RateLimitService{
settingService: NewSettingService(repo, &config.Config{}),
}
}
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
......@@ -45,6 +242,230 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10106)
accounts := []Account{
{
ID: 36001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
},
{
ID: 36002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cache := &schedulerTestGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_disabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
require.False(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10108)
accounts := []Account{
{
ID: 36011,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
{
ID: 36012,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
}
cfg := newSchedulerTestOpenAIWSV2Config()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36012), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10109)
accounts := []Account{
{
ID: 36021,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := newSchedulerTestOpenAIWSV2Config()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
)
require.ErrorContains(t, err, "no available OpenAI accounts")
require.Nil(t, selection)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10107)
accounts := []Account{
{
ID: 37001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
{
ID: 37002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_enabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(37001), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
require.True(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
......@@ -53,10 +474,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
cache: cache,
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
......@@ -76,7 +504,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
......@@ -92,18 +525,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
......@@ -128,8 +562,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
......@@ -153,7 +588,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
cache := &schedulerTestGatewayCache{}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
......@@ -163,10 +598,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
......@@ -204,17 +640,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true,
Concurrency: 1,
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
......@@ -260,7 +697,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9,
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001,
},
......@@ -273,7 +710,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{
21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
......@@ -288,9 +725,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
......@@ -328,17 +766,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true,
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
......@@ -387,15 +826,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201,
},
}
cfg := newOpenAIWSV2TestConfig()
cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
......@@ -403,9 +842,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
......@@ -445,10 +885,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{},
cfg: newOpenAIWSV2TestConfig(),
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: newSchedulerTestOpenAIWSV2Config(),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
......@@ -507,7 +948,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
......@@ -520,9 +961,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
......@@ -559,16 +1001,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true,
Concurrency: 1,
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
......@@ -749,7 +1192,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
......@@ -757,9 +1200,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
......@@ -905,12 +1349,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
......@@ -947,7 +1393,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
cfg := newOpenAIWSV2TestConfig()
cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,
......
......@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
cache: &stubGatewayCache{},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
......
......@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
responsesBody = stripped
}
}
responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
if err != nil {
return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
}
// Minimal stub populated from the raw body so downstream billing
// propagation (ServiceTier, ReasoningEffort) keeps working.
responsesReq = &apicompat.ResponsesRequest{
Model: upstreamModel,
ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
ServiceTier: normalizedServiceTier,
}
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
......@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = upstreamModel
normalizeResponsesRequestServiceTier(responsesReq)
responsesBody, err = json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
......@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return result, handleErr
}
func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
if req == nil {
return
}
req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
}
func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
if len(body) == 0 {
return body, "", nil
}
rawServiceTier := gjson.GetBytes(body, "service_tier").String()
if rawServiceTier == "" {
return body, "", nil
}
normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
if normalizedServiceTier == "" {
trimmed, err := sjson.DeleteBytes(body, "service_tier")
return trimmed, "", err
}
if normalizedServiceTier == rawServiceTier {
return body, normalizedServiceTier, nil
}
trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
return trimmed, normalizedServiceTier, err
}
func normalizedOpenAIServiceTierValue(raw string) string {
normalized := normalizeOpenAIServiceTier(raw)
if normalized == nil {
return ""
}
return *normalized
}
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
......
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
t.Parallel()
req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "priority", req.ServiceTier)
req.ServiceTier = "flex"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "flex", req.ServiceTier)
req.ServiceTier = "default"
normalizeResponsesRequestServiceTier(req)
require.Empty(t, req.ServiceTier)
}
func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
t.Parallel()
body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
require.NoError(t, err)
require.Equal(t, "priority", tier)
require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
require.NoError(t, err)
require.Equal(t, "flex", tier)
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
require.NoError(t, err)
require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}
......@@ -20,6 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
......@@ -31,6 +32,27 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
if len(typeInstances) == 0 {
return typeInstances
}
filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances))
for paymentType, groupedInstances := range typeInstances {
filtered[paymentType] = groupedInstances
}
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
matching := filterEnabledVisibleMethodInstances(instances, method)
if len(matching) != 1 {
delete(filtered, method)
continue
}
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]}
}
return filtered
}
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
......
package service
import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
......@@ -299,3 +300,66 @@ func TestPcInstanceTypeLimits(t *testing.T) {
}
})
}
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
svc := &PaymentConfigService{
entClient: client,
}
resp, err := svc.GetAvailableMethodLimits(ctx)
if err != nil {
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
}
if _, ok := resp.Methods[payment.TypeAlipay]; ok {
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay])
}
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != 30 || resp.GlobalMax != 300 {
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax)
}
}
......@@ -150,6 +150,9 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil {
return nil, err
}
if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil {
return nil, err
}
if req.Enabled {
if err := s.validateProviderConfig(req.ProviderKey, req.Config); err != nil {
return nil, err
......@@ -183,26 +186,25 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
var cachedInst *dbent.PaymentProviderInstance
loadInst := func() (*dbent.PaymentProviderInstance, error) {
if cachedInst != nil {
return cachedInst, nil
}
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
cachedInst = inst
return inst, nil
current, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
nextEnabled := current.Enabled
if req.Enabled != nil {
nextEnabled = *req.Enabled
}
nextSupportedTypes := current.SupportedTypes
if req.SupportedTypes != nil {
nextSupportedTypes = joinTypes(req.SupportedTypes)
}
if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil {
return nil, err
}
if req.Config != nil {
inst, err := loadInst()
if err != nil {
return nil, err
}
hasSensitive := false
for k, v := range req.Config {
if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) {
hasSensitive = true
break
}
......@@ -231,11 +233,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
// Validate merged config when the instance will end up enabled.
// This surfaces provider-level errors (e.g. wxpay missing certSerial) at save time,
// so admins see them in the dialog instead of only when an order is created.
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
finalEnabled := inst.Enabled
finalEnabled := current.Enabled
if req.Enabled != nil {
finalEnabled = *req.Enabled
}
......@@ -249,12 +247,12 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if finalEnabled {
configToValidate := mergedConfig
if configToValidate == nil {
configToValidate, err = s.decryptConfig(inst.Config)
configToValidate, err = s.decryptConfig(current.Config)
if err != nil {
return nil, fmt.Errorf("decrypt existing config: %w", err)
}
}
if err := s.validateProviderConfig(inst.ProviderKey, configToValidate); err != nil {
if err := s.validateProviderConfig(current.ProviderKey, configToValidate); err != nil {
return nil, err
}
}
......@@ -277,11 +275,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if count > 0 {
// Load current instance to compare types
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
oldTypes := strings.Split(inst.SupportedTypes, ",")
oldTypes := strings.Split(current.SupportedTypes, ",")
newTypes := req.SupportedTypes
for _, ot := range oldTypes {
ot = strings.TrimSpace(ot)
......@@ -326,10 +320,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if req.RefundEnabled != nil {
refundEnabled = *req.RefundEnabled
} else {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil {
refundEnabled = inst.RefundEnabled
}
refundEnabled = current.RefundEnabled
}
if refundEnabled {
u.SetAllowUserRefund(true)
......
......@@ -3,8 +3,10 @@
package service
import (
"context"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
......@@ -196,3 +198,122 @@ func TestJoinTypes(t *testing.T) {
})
}
}
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
_, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "easypay",
Name: "EasyPay Alipay",
Config: map[string]string{
"pid": "1001",
"pkey": "pkey-1001",
"apiBase": "https://pay.example.com",
"notifyUrl": "https://merchant.example.com/notify",
"returnUrl": "https://merchant.example.com/return",
},
SupportedTypes: []string{"alipay"},
Enabled: true,
})
require.NoError(t, err)
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "alipay",
Name: "Official Alipay",
Config: map[string]string{"appId": "app-1"},
SupportedTypes: []string{"alipay"},
Enabled: true,
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
}
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
existing, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "easypay",
Name: "EasyPay WeChat",
Config: map[string]string{
"pid": "2001",
"pkey": "pkey-2001",
"apiBase": "https://pay.example.com",
"notifyUrl": "https://merchant.example.com/notify",
"returnUrl": "https://merchant.example.com/return",
},
SupportedTypes: []string{"wxpay"},
Enabled: true,
})
require.NoError(t, err)
require.NotNil(t, existing)
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "wxpay",
Name: "Official WeChat",
Config: map[string]string{"appId": "wx-app"},
SupportedTypes: []string{"wxpay"},
Enabled: false,
})
require.NoError(t, err)
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
Enabled: boolPtrValue(true),
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
}
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "easypay",
Name: "EasyPay",
Config: map[string]string{
"pid": "3001",
"pkey": "pkey-3001",
"apiBase": "https://pay.example.com",
"notifyUrl": "https://merchant.example.com/notify",
"returnUrl": "https://merchant.example.com/return",
},
SupportedTypes: []string{"alipay"},
Enabled: false,
})
require.NoError(t, err)
_, err = svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Enabled: boolPtrValue(true),
SupportedTypes: []string{"alipay", "wxpay"},
})
require.NoError(t, err)
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
require.True(t, saved.Enabled)
require.Equal(t, "alipay,wxpay", saved.SupportedTypes)
}
func boolPtrValue(v bool) *bool {
return &v
}
......@@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct {
CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
}
// MethodLimits holds per-payment-type limits.
......@@ -196,6 +201,8 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
......@@ -234,18 +241,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
cfg.EnabledTypes = append(cfg.EnabledTypes, t)
types = append(types, t)
}
}
cfg.EnabledTypes = NormalizeVisibleMethods(types)
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
if s.entClient == nil {
return ""
}
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
......@@ -282,25 +294,29 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda
}
}
m := map[string]string{
SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier),
SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate),
SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
SettingHelpImageURL: derefStr(req.HelpImageURL),
SettingHelpText: derefStr(req.HelpText),
SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier),
SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate),
SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
SettingHelpImageURL: derefStr(req.HelpImageURL),
SettingHelpText: derefStr(req.HelpText),
SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource),
SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource),
SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled),
SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled),
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
......@@ -385,3 +401,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return v
}
func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
available := make(map[string]bool, 4)
for _, inst := range instances {
switch inst.ProviderKey {
case payment.TypeAlipay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
available[VisibleMethodSourceOfficialAlipay] = true
}
case payment.TypeWxpay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
available[VisibleMethodSourceOfficialWechat] = true
}
case payment.TypeEasyPay:
for _, supportedType := range splitTypes(inst.SupportedTypes) {
switch NormalizeVisibleMethod(supportedType) {
case payment.TypeAlipay:
available[VisibleMethodSourceEasyPayAlipay] = true
case payment.TypeWxpay:
available[VisibleMethodSourceEasyPayWechat] = true
}
}
}
}
return available
}
func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
shouldExpose := map[string]bool{
payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
}
seen := make(map[string]struct{}, len(base)+2)
out := make([]string, 0, len(base)+2)
appendType := func(paymentType string) {
paymentType = NormalizeVisibleMethod(paymentType)
if paymentType == "" {
return
}
if _, ok := seen[paymentType]; ok {
return
}
seen[paymentType] = struct{}{}
out = append(out, paymentType)
}
for _, paymentType := range base {
visibleMethod := NormalizeVisibleMethod(paymentType)
switch visibleMethod {
case payment.TypeAlipay, payment.TypeWxpay:
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
default:
appendType(visibleMethod)
}
}
for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
}
return out
}
func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
enabledKey := visibleMethodEnabledSettingKey(method)
sourceKey := visibleMethodSourceSettingKey(method)
if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
return false
}
source := NormalizeVisibleMethodSource(method, vals[sourceKey])
return source != "" && available[source]
}
package service
import (
"context"
"database/sql"
"fmt"
"strings"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func TestPcParseFloat(t *testing.T) {
......@@ -163,6 +173,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 2 {
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
}
})
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
......@@ -204,3 +228,210 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
t.Parallel()
base := []string{"alipay", "wxpay", "stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
}
available := map[string]bool{
VisibleMethodSourceOfficialAlipay: true,
VisibleMethodSourceOfficialWechat: false,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"alipay", "stripe"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
t.Parallel()
base := []string{"stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
}
available := map[string]bool{
VisibleMethodSourceEasyPayAlipay: true,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"stripe", "alipay"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
t.Parallel()
instances := []*dbent.PaymentProviderInstance{
{ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
{ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
{ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
}
got := buildVisibleMethodSourceAvailability(instances)
if !got[VisibleMethodSourceOfficialAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
}
if !got[VisibleMethodSourceEasyPayAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
}
if !got[VisibleMethodSourceOfficialWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
}
if !got[VisibleMethodSourceEasyPayWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
}
}
func TestGetPaymentConfigKeepsStoredEnabledTypes(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay instance: %v", err)
}
svc := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
},
},
}
cfg, err := svc.GetPaymentConfig(ctx)
if err != nil {
t.Fatalf("GetPaymentConfig returned error: %v", err)
}
want := []string{payment.TypeAlipay, payment.TypeWxpay, payment.TypeStripe}
if len(cfg.EnabledTypes) != len(want) {
t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
}
for i := range want {
if cfg.EnabledTypes[i] != want[i] {
t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
}
}
}
func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
t.Helper()
dbName := fmt.Sprintf(
"file:%s?mode=memory&cache=shared",
strings.NewReplacer("/", "_", " ", "_").Replace(t.Name()),
)
db, err := sql.Open("sqlite", dbName)
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
t.Fatalf("enable foreign keys: %v", err)
}
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
type paymentConfigSettingRepoStub struct {
values map[string]string
updates map[string]string
}
func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
return nil, nil
}
func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
return s.values[key], nil
}
func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = s.values[key]
}
return out, nil
}
func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error {
s.updates = make(map[string]string, len(values))
for key, value := range values {
s.updates[key] = value
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) {
repo := &paymentConfigSettingRepoStub{values: map[string]string{}}
svc := &PaymentConfigService{settingRepo: repo}
alipayEnabled := true
wxpayEnabled := false
err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{
VisibleMethodAlipayEnabled: &alipayEnabled,
VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay),
VisibleMethodWxpayEnabled: &wxpayEnabled,
VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat),
})
if err != nil {
t.Fatalf("UpdatePaymentConfig returned error: %v", err)
}
if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" {
t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled])
}
if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay {
t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay)
}
if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" {
t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled])
}
if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat {
t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat)
}
}
func paymentConfigStrPtr(value string) *string {
return &value
}
......@@ -25,22 +25,61 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
// Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
// Fallback only for true legacy "sub2_N" DB-ID payloads when the
// current out_trade_no lookup genuinely did not find an order.
if oid, ok := parseLegacyPaymentOrderID(n.OrderID, err); ok {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
func parseLegacyPaymentOrderID(orderID string, lookupErr error) (int64, bool) {
if !dbent.IsNotFound(lookupErr) {
return 0, false
}
orderID = strings.TrimSpace(orderID)
if !strings.HasPrefix(orderID, orderIDPrefix) {
return 0, false
}
trimmed := strings.TrimPrefix(orderID, orderIDPrefix)
if trimmed == "" || trimmed == orderID {
return 0, false
}
oid, err := strconv.ParseInt(trimmed, 10, 64)
if err != nil || oid <= 0 {
return 0, false
}
return oid, true
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
return nil
}
instanceProviderKey := ""
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
instanceProviderKey = inst.ProviderKey
}
expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
"expectedProvider": expectedProviderKey,
"actualProvider": pk,
"tradeNo": tradeNo,
})
return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
}
if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{
"detail": err.Error(),
"tradeNo": tradeNo,
})
return err
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
......@@ -56,6 +95,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
return validateProviderSnapshotMetadata(order, providerKey, metadata)
}
func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
if key := strings.TrimSpace(instanceProviderKey); key != "" {
return key
}
if key := strings.TrimSpace(orderProviderKey); key != "" {
return key
}
if registry != nil {
if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" {
return key
}
}
return strings.TrimSpace(orderPaymentType)
}
func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error {
previousStatus := o.Status
now := time.Now()
......
......@@ -3,12 +3,38 @@
package service
import (
"context"
"errors"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/assert"
)
type paymentFulfillmentTestProvider struct {
key string
supportedTypes []payment.PaymentType
}
func (p paymentFulfillmentTestProvider) Name() string { return p.key }
func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key }
func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType {
return p.supportedTypes
}
func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
panic("unexpected call")
}
func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
panic("unexpected call")
}
func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}
// ---------------------------------------------------------------------------
// resolveRedeemAction — pure idempotency decision logic
// ---------------------------------------------------------------------------
......@@ -161,3 +187,171 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) {
assert.True(t, unusedCode.CanUse())
assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil))
}
func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) {
t.Parallel()
registry := payment.NewRegistry()
registry.Register(paymentFulfillmentTestProvider{
key: payment.TypeAlipay,
supportedTypes: []payment.PaymentType{payment.TypeAlipay},
})
assert.Equal(t,
payment.TypeEasyPay,
expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay),
)
}
func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) {
t.Parallel()
registry := payment.NewRegistry()
registry.Register(paymentFulfillmentTestProvider{
key: payment.TypeEasyPay,
supportedTypes: []payment.PaymentType{payment.TypeAlipay},
})
assert.Equal(t,
payment.TypeEasyPay,
expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""),
)
}
func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) {
t.Parallel()
assert.Equal(t,
payment.TypeWxpay,
expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""),
)
}
func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) {
t.Parallel()
registry := payment.NewRegistry()
registry.Register(paymentFulfillmentTestProvider{
key: payment.TypeAlipay,
supportedTypes: []payment.PaymentType{payment.TypeAlipay},
})
assert.Equal(t,
payment.TypeEasyPay,
expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
)
}
func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
t.Parallel()
registry := payment.NewRegistry()
registry.Register(paymentFulfillmentTestProvider{
key: payment.TypeAlipay,
supportedTypes: []payment.PaymentType{payment.TypeAlipay},
})
order := &dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_key": payment.TypeEasyPay,
},
}
assert.Equal(t,
payment.TypeEasyPay,
expectedNotificationProviderKeyForOrder(registry, order, ""),
)
}
func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"merchant_app_id": "wx-app-expected",
"merchant_id": "mch-expected",
"currency": "CNY",
},
}
err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
"appid": "wx-app-other",
"mchid": "mch-expected",
"currency": "CNY",
"trade_state": "SUCCESS",
})
assert.ErrorContains(t, err, "wxpay appid mismatch")
}
func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_instance_id": "9",
"provider_key": payment.TypeWxpay,
},
}
err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
"appid": "wx-app-runtime",
"mchid": "mch-runtime",
"currency": "CNY",
"trade_state": "SUCCESS",
})
assert.NoError(t, err)
}
func TestParseLegacyPaymentOrderID(t *testing.T) {
t.Parallel()
oid, ok := parseLegacyPaymentOrderID("sub2_42", &dbent.NotFoundError{})
assert.True(t, ok)
assert.EqualValues(t, 42, oid)
_, ok = parseLegacyPaymentOrderID("42", &dbent.NotFoundError{})
assert.False(t, ok)
_, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down"))
assert.False(t, ok)
}
func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderSnapshot: map[string]any{
"schema_version": 2,
"merchant_app_id": "alipay-app-expected",
},
}
err := validateProviderNotificationMetadata(order, payment.TypeAlipay, map[string]string{
"app_id": "alipay-app-other",
})
assert.ErrorContains(t, err, "alipay app_id mismatch")
}
func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderSnapshot: map[string]any{
"schema_version": 2,
"merchant_id": "pid-expected",
},
}
err := validateProviderNotificationMetadata(order, payment.TypeEasyPay, map[string]string{
"pid": "pid-other",
})
assert.ErrorContains(t, err, "easypay pid mismatch")
}
......@@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"math"
"net/url"
"strconv"
"strings"
"time"
......@@ -23,6 +24,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if req.OrderType == "" {
req.OrderType = payment.OrderTypeBalance
}
if normalized := NormalizeVisibleMethod(req.PaymentType); normalized != "" {
req.PaymentType = normalized
}
cfg, err := s.configService.GetPaymentConfig(ctx)
if err != nil {
return nil, fmt.Errorf("get payment config: %w", err)
......@@ -55,11 +59,25 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
feeRate := cfg.RechargeFeeRate
payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate)
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount)
sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount)
if err != nil {
return nil, err
}
if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil {
return nil, err
}
oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel)
if err != nil {
return nil, err
}
if oauthResp != nil {
return oauthResp, nil
}
order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount, sel)
if err != nil {
return nil, err
}
resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan)
resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan, sel)
if err != nil {
_, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID).
SetStatus(OrderStatusFailed).
......@@ -104,7 +122,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe
return plan, nil
}
func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64, sel *payment.InstanceSelection) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
......@@ -121,6 +139,13 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
if sel != nil {
selectedInstanceID = strings.TrimSpace(sel.InstanceID)
selectedProviderKey = strings.TrimSpace(sel.ProviderKey)
}
b := tx.PaymentOrder.Create().
SetUserID(req.UserID).
SetUserEmail(user.Email).
......@@ -141,6 +166,15 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
if req.SrcURL != "" {
b.SetSrcURL(req.SrcURL)
}
if selectedInstanceID != "" {
b.SetProviderInstanceID(selectedInstanceID)
}
if selectedProviderKey != "" {
b.SetProviderKey(selectedProviderKey)
}
if providerSnapshot != nil {
b.SetProviderSnapshot(providerSnapshot)
}
if plan != nil {
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
}
......@@ -174,6 +208,65 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any {
if sel == nil {
return nil
}
snapshot := map[string]any{}
snapshot["schema_version"] = 2
instanceID := strings.TrimSpace(sel.InstanceID)
if instanceID != "" {
snapshot["provider_instance_id"] = instanceID
}
providerKey := strings.TrimSpace(sel.ProviderKey)
if providerKey != "" {
snapshot["provider_key"] = providerKey
}
paymentMode := strings.TrimSpace(sel.PaymentMode)
if paymentMode != "" {
snapshot["payment_mode"] = paymentMode
}
if providerKey == payment.TypeWxpay {
if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" {
snapshot["merchant_app_id"] = merchantAppID
}
if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
snapshot["merchant_id"] = merchantID
}
snapshot["currency"] = "CNY"
}
if providerKey == payment.TypeAlipay {
if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
snapshot["merchant_app_id"] = merchantAppID
}
}
if providerKey == payment.TypeEasyPay {
if merchantID := strings.TrimSpace(sel.Config["pid"]); merchantID != "" {
snapshot["merchant_id"] = merchantID
}
}
if len(snapshot) == 1 {
return nil
}
return snapshot
}
func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string {
if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay {
return ""
}
if strings.TrimSpace(req.OpenID) != "" {
return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config))
}
return strings.TrimSpace(sel.Config["appId"])
}
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
......@@ -198,10 +291,12 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
return nil
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
// Select an instance across all providers that support the requested payment type.
// This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) {
selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req)
if err != nil {
return nil, err
}
sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "method_not_configured").
WithMetadata(map[string]string{"payment_type": req.PaymentType})
......@@ -209,6 +304,45 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
if sel == nil {
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no_available_instance")
}
return sel, nil
}
func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) {
if !requestNeedsWeChatJSAPICompatibility(req) {
return ctx, nil
}
if !s.usesOfficialWxpayVisibleMethod(ctx) {
return ctx, nil
}
expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
if err != nil {
return nil, err
}
return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil
}
func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool {
if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
return false
}
return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
}
func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool {
if s == nil || s.configService == nil {
return false
}
inst, err := s.configService.resolveEnabledVisibleMethodInstance(ctx, payment.TypeWxpay)
if err != nil {
return false
}
if inst == nil {
return false
}
return inst.ProviderKey == payment.TypeWxpay
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil {
slog.Error("[PaymentService] CreateProvider failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
......@@ -226,16 +360,52 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost)
if err != nil {
return nil, err
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
if resume.isSigningConfigured() {
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
ProviderInstanceID: sel.InstanceID,
ProviderKey: sel.ProviderKey,
PaymentType: req.PaymentType,
CanonicalReturnURL: canonicalReturnURL,
})
if err != nil {
return nil, fmt.Errorf("create payment resume token: %w", err)
}
}
}
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken)
if err != nil {
return nil, err
}
providerReq := buildProviderCreatePaymentRequest(CreateOrderRequest{
PaymentType: req.PaymentType,
OpenID: req.OpenID,
ClientIP: req.ClientIP,
IsMobile: req.IsMobile,
ReturnURL: providerReturnURL,
}, sel, outTradeNo, payAmountStr, subject)
pr, err := prov.CreatePayment(ctx, providerReq)
if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
return nil, appErr
}
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment_gateway_error").
WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
}
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
return nil, classifyCreatePaymentError(req, sel.ProviderKey, err)
}
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).
SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).
SetNillablePayURL(psNilIfEmpty(pr.PayURL)).
SetNillableQrCode(psNilIfEmpty(pr.QRCode)).
SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).
SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)).
Save(ctx)
if err != nil {
return nil, fmt.Errorf("update order with payment details: %w", err)
}
......@@ -245,8 +415,36 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
"payAmount": order.PayAmount,
"paymentType": req.PaymentType,
"orderType": req.OrderType,
"paymentSource": NormalizePaymentSource(req.PaymentSource),
})
return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil
resultType := pr.ResultType
if resultType == "" {
resultType = payment.CreatePaymentResultOrderCreated
}
resp := buildCreateOrderResponse(order, req, payAmount, sel, pr, resultType)
resp.ResumeToken = resumeToken
return resp, nil
}
func buildProviderCreatePaymentRequest(req CreateOrderRequest, sel *payment.InstanceSelection, orderID, amount, subject string) payment.CreatePaymentRequest {
return payment.CreatePaymentRequest{
OrderID: orderID,
Amount: amount,
PaymentType: req.PaymentType,
Subject: subject,
ReturnURL: req.ReturnURL,
OpenID: strings.TrimSpace(req.OpenID),
ClientIP: req.ClientIP,
IsMobile: req.IsMobile,
InstanceSubMethods: selectedInstanceSupportedTypes(sel),
}
}
func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string {
if sel == nil {
return ""
}
return sel.SupportedTypes
}
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string {
......@@ -265,6 +463,190 @@ func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limit
return "Sub2API " + amountStr + " CNY"
}
func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
return s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, amount, payAmount, feeRate, nil)
}
func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponseForSelection(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
if sel != nil && sel.ProviderKey != "" && sel.ProviderKey != payment.TypeWxpay {
return nil, nil
}
if strings.TrimSpace(req.OpenID) != "" || !req.IsWeChatBrowser || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
return nil, nil
}
return s.buildWeChatOAuthRequiredResponse(ctx, req, amount, payAmount, feeRate)
}
func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
appID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
if err != nil {
return nil, err
}
authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base")
if err != nil {
return nil, err
}
return &CreateOrderResponse{
Amount: amount,
PayAmount: payAmount,
FeeRate: feeRate,
ResultType: payment.CreatePaymentResultOAuthRequired,
PaymentType: req.PaymentType,
OAuth: &payment.WechatOAuthInfo{
AuthorizeURL: authorizeURL,
AppID: appID,
Scope: "snsapi_base",
RedirectURL: "/auth/wechat/payment/callback",
},
}, nil
}
func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context, req CreateOrderRequest, sel *payment.InstanceSelection) error {
if !requiresWeChatJSAPICompatibleSelection(req, sel) {
return nil
}
expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
if err != nil {
return err
}
selectedAppID := provider.ResolveWxpayJSAPIAppID(sel.Config)
if selectedAppID == "" || selectedAppID != expectedAppID {
return infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "selected payment instance is not compatible with the current WeChat OAuth app")
}
return nil
}
func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool {
if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
return false
}
return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
}
func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) {
if s == nil || s.configService == nil || s.configService.settingRepo == nil {
return "", "", infraerrors.ServiceUnavailable(
"WECHAT_PAYMENT_MP_NOT_CONFIGURED",
"wechat in-app payment requires a complete WeChat MP OAuth credential",
)
}
cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx)
appID := strings.TrimSpace(cfg.AppIDForMode("mp"))
appSecret := strings.TrimSpace(cfg.AppSecretForMode("mp"))
if err != nil || !cfg.SupportsMode("mp") || appID == "" || appSecret == "" {
return "", "", infraerrors.ServiceUnavailable(
"WECHAT_PAYMENT_MP_NOT_CONFIGURED",
"wechat in-app payment requires a complete WeChat MP OAuth credential",
)
}
return appID, appSecret, nil
}
func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error {
if err == nil {
return nil
}
if providerKey == payment.TypeWxpay &&
payment.GetBasePaymentType(req.PaymentType) == payment.TypeWxpay &&
strings.Contains(err.Error(), "wxpay h5 payments are not authorized for this merchant") {
return infraerrors.ServiceUnavailable(
"WECHAT_H5_NOT_AUTHORIZED",
"wechat h5 payment is not available for this merchant",
).WithMetadata(map[string]string{
"action": "open_in_wechat_or_scan_qr",
})
}
return infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
}
func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest, payAmount float64, sel *payment.InstanceSelection, pr *payment.CreatePaymentResponse, resultType payment.CreatePaymentResultType) *CreateOrderResponse {
return &CreateOrderResponse{
OrderID: order.ID,
Amount: order.Amount,
PayAmount: payAmount,
FeeRate: order.FeeRate,
Status: OrderStatusPending,
ResultType: resultType,
PaymentType: req.PaymentType,
OutTradeNo: order.OutTradeNo,
PayURL: pr.PayURL,
QRCode: pr.QRCode,
ClientSecret: pr.ClientSecret,
OAuth: pr.OAuth,
JSAPI: pr.JSAPI,
JSAPIPayload: pr.JSAPI,
ExpiresAt: order.ExpiresAt,
PaymentMode: sel.PaymentMode,
}
}
func buildWeChatPaymentOAuthStartURL(req CreateOrderRequest, scope string) (string, error) {
u, err := url.Parse("/api/v1/auth/oauth/wechat/payment/start")
if err != nil {
return "", fmt.Errorf("build wechat payment oauth start url: %w", err)
}
q := u.Query()
q.Set("payment_type", strings.TrimSpace(req.PaymentType))
if req.Amount > 0 {
q.Set("amount", strconv.FormatFloat(req.Amount, 'f', -1, 64))
}
if orderType := strings.TrimSpace(req.OrderType); orderType != "" {
q.Set("order_type", orderType)
}
if req.PlanID > 0 {
q.Set("plan_id", strconv.FormatInt(req.PlanID, 10))
}
if scope = strings.TrimSpace(scope); scope != "" {
q.Set("scope", scope)
}
if redirectTo := paymentRedirectPathFromURL(req.SrcURL); redirectTo != "" {
q.Set("redirect", redirectTo)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func paymentRedirectPathFromURL(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return "/purchase"
}
if strings.HasPrefix(rawURL, "/") && !strings.HasPrefix(rawURL, "//") {
return normalizePaymentRedirectPath(rawURL)
}
u, err := url.Parse(rawURL)
if err != nil {
return "/purchase"
}
path := strings.TrimSpace(u.EscapedPath())
if path == "" {
path = strings.TrimSpace(u.Path)
}
if path == "" || !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
return "/purchase"
}
if strings.TrimSpace(u.RawQuery) != "" {
path += "?" + u.RawQuery
}
return normalizePaymentRedirectPath(path)
}
func normalizePaymentRedirectPath(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return "/purchase"
}
if path == "/payment" {
return "/purchase"
}
if strings.HasPrefix(path, "/payment?") {
return "/purchase" + strings.TrimPrefix(path, "/payment")
}
return path
}
// --- Order Queries ---
func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment