Commit 645609d4 authored by IanShaw027's avatar IanShaw027
Browse files

merge: 正确合并 main 分支改动

合并 origin/main 最新改动,正确保留所有配置:
- Ops 运维监控配置和功能
- LinuxDo Connect OAuth 配置
- Update 在线更新配置
- 优惠码功能
- 其他 main 分支新功能

修复之前合并时错误删除 LinuxDo 和 Update 配置的问题。
parents fc4ea659 56fc2764
...@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier, RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive, IsExclusive: g.IsExclusive,
Status: g.Status, Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd, DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd,
......
...@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
} }
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
out, err := r.GetByIDLite(ctx, id)
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
}
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
// AccountCount is intentionally not loaded here; use GetByID when needed.
m, err := r.client.Group.Query(). m, err := r.client.Group.Query().
Where(group.IDEQ(id)). Where(group.IDEQ(id)).
Only(ctx) Only(ctx)
...@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group ...@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
out := groupEntityToService(m) return groupEntityToService(m), nil
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
} }
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error { func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
......
...@@ -4,6 +4,8 @@ package repository ...@@ -4,6 +4,8 @@ package repository
import ( import (
"context" "context"
"database/sql"
"errors"
"testing" "testing"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
...@@ -19,6 +21,20 @@ type GroupRepoSuite struct { ...@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo *groupRepository repo *groupRepository
} }
type forbidSQLExecutor struct {
called bool
}
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
s.called = true
return nil, errors.New("unexpected sql exec")
}
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
s.called = true
return nil, errors.New("unexpected sql query")
}
func (s *GroupRepoSuite) SetupTest() { func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
tx := testEntTx(s.T()) tx := testEntTx(s.T())
...@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() { ...@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s.Require().ErrorIs(err, service.ErrGroupNotFound) s.Require().ErrorIs(err, service.ErrGroupNotFound)
} }
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
group := &service.Group{
Name: "lite-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
spy := &forbidSQLExecutor{}
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
got, err := repo.GetByIDLite(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.ID, got.ID)
s.Require().False(spy.called, "expected no direct sql executor usage")
}
func (s *GroupRepoSuite) TestUpdate() { func (s *GroupRepoSuite) TestUpdate() {
group := &service.Group{ group := &service.Group{
Name: "original", Name: "original",
......
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type promoCodeRepository struct {
client *dbent.Client
}
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
return &promoCodeRepository{client: client}
}
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.Create().
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
code.ID = created.ID
code.CreatedAt = created.CreatedAt
code.UpdatedAt = created.UpdatedAt
return nil
}
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.IDEQ(id)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
client := clientFromContext(ctx, r.client)
m, err := client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
ForUpdate().
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
updated, err := builder.Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrPromoCodeNotFound
}
return err
}
code.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
return err
}
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "")
}
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
q := r.client.PromoCode.Query()
if status != "" {
q = q.Where(promocode.StatusEQ(status))
}
if search != "" {
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
codes, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocode.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outCodes := promoCodeEntitiesToService(codes)
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create().
SetPromoCodeID(usage.PromoCodeID).
SetUserID(usage.UserID).
SetBonusAmount(usage.BonusAmount).
SetUsedAt(usage.UsedAt).
Save(ctx)
if err != nil {
return err
}
usage.ID = created.ID
return nil
}
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
m, err := r.client.PromoCodeUsage.Query().
Where(
promocodeusage.PromoCodeIDEQ(promoCodeID),
promocodeusage.UserIDEQ(userID),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return promoCodeUsageEntityToService(m), nil
}
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
usages, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocodeusage.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outUsages := promoCodeUsageEntitiesToService(usages)
return outUsages, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.UpdateOneID(id).
AddUsedCount(1).
Save(ctx)
return err
}
// Entity to Service conversions
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
if m == nil {
return nil
}
return &service.PromoCode{
ID: m.ID,
Code: m.Code,
BonusAmount: m.BonusAmount,
MaxUses: m.MaxUses,
UsedCount: m.UsedCount,
Status: m.Status,
ExpiresAt: m.ExpiresAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
out := make([]service.PromoCode, 0, len(models))
for i := range models {
if s := promoCodeEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
if m == nil {
return nil
}
out := &service.PromoCodeUsage{
ID: m.ID,
PromoCodeID: m.PromoCodeID,
UserID: m.UserID,
BonusAmount: m.BonusAmount,
UsedAt: m.UsedAt,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
return out
}
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
out := make([]service.PromoCodeUsage, 0, len(models))
for i := range models {
if s := promoCodeUsageEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
...@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet( ...@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
NewAccountRepository, NewAccountRepository,
NewProxyRepository, NewProxyRepository,
NewRedeemCodeRepository, NewRedeemCodeRepository,
NewPromoCodeRepository,
NewUsageLogRepository, NewUsageLogRepository,
NewSettingRepository, NewSettingRepository,
NewOpsRepository, NewOpsRepository,
......
...@@ -327,10 +327,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -327,10 +327,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o", "fallback_model_openai": "gpt-4o",
"enable_identity_patch": true, "enable_identity_patch": true,
"identity_patch_prompt": "", "identity_patch_prompt": "",
"ops_monitoring_enabled": true, "home_content": ""
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
"ops_metrics_interval_seconds": 60
} }
}`, }`,
}, },
...@@ -402,7 +399,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -402,7 +399,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
...@@ -579,6 +576,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err ...@@ -579,6 +576,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return nil, service.ErrGroupNotFound return nil, service.ErrGroupNotFound
} }
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error { func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9"
) )
// ProviderSet 提供服务器层的依赖 // ProviderSet 提供服务器层的依赖
...@@ -31,6 +32,8 @@ func ProvideRouter( ...@@ -31,6 +32,8 @@ func ProvideRouter(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
settingService *service.SettingService,
redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
...@@ -48,7 +51,7 @@ func ProvideRouter( ...@@ -48,7 +51,7 @@ func ProvideRouter(
} }
} }
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
} }
// ProvideHTTPServer 提供 HTTP 服务器 // ProvideHTTPServer 提供 HTTP 服务器
......
package middleware package middleware
import ( import (
"context"
"errors" "errors"
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
return return
} }
...@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
} }
...@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool ...@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription, ok := value.(*service.UserSubscription) subscription, ok := value.(*service.UserSubscription)
return subscription, ok return subscription, ok
} }
func setGroupContext(c *gin.Context, group *service.Group) {
if !service.IsGroupContextValid(group) {
return
}
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
return
}
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
c.Request = c.Request.WithContext(ctx)
}
...@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
return return
} }
...@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
} }
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) { ...@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status) require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 99,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformGemini,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyService := service.NewAPIKeyService(
fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
},
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)
cfg := &config.Config{RunMode: config.RunModeSimple}
r := gin.New()
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID: 42, ID: 42,
Name: "sub", Name: "sub",
Status: service.StatusActive, Status: service.StatusActive,
Hydrated: true,
SubscriptionType: service.SubscriptionTypeSubscription, SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit, DailyLimitUSD: &limit,
} }
...@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}) })
} }
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
invalidGroup := &service.Group{
ID: group.ID,
Platform: group.Platform,
Status: group.Status,
}
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
......
package server package server
import ( import (
"log"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -9,6 +11,7 @@ import ( ...@@ -9,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web" "github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// SetupRouter 配置路由器中间件和路由 // SetupRouter 配置路由器中间件和路由
...@@ -21,20 +24,30 @@ func SetupRouter( ...@@ -21,20 +24,30 @@ func SetupRouter(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware2.Logger()) r.Use(middleware2.Logger())
r.Use(middleware2.CORS(cfg.CORS)) r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available // Serve embedded frontend with settings injection if available
if web.HasEmbeddedFrontend() { if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend()) frontendServer, err := web.NewFrontendServer(settingService)
if err != nil {
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
r.Use(web.ServeEmbeddedFrontend())
} else {
// Register cache invalidation callback
settingService.SetOnUpdateCallback(frontendServer.InvalidateCache)
r.Use(frontendServer.Middleware())
}
} }
// 注册路由 // 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
return r return r
} }
...@@ -50,6 +63,7 @@ func registerRoutes( ...@@ -50,6 +63,7 @@ func registerRoutes(
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client,
) { ) {
// 通用路由(健康检查、状态等) // 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r) routes.RegisterCommonRoutes(r)
...@@ -58,7 +72,7 @@ func registerRoutes( ...@@ -58,7 +72,7 @@ func registerRoutes(
v1 := r.Group("/api/v1") v1 := r.Group("/api/v1")
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
......
...@@ -44,6 +44,9 @@ func RegisterAdminRoutes( ...@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
// 卡密管理 // 卡密管理
registerRedeemCodeRoutes(admin, h) registerRedeemCodeRoutes(admin, h)
// 优惠码管理
registerPromoCodeRoutes(admin, h)
// 系统设置 // 系统设置
registerSettingsRoutes(admin, h) registerSettingsRoutes(admin, h)
...@@ -252,6 +255,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -252,6 +255,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
promoCodes := admin.Group("/promo-codes")
{
promoCodes.GET("", h.Admin.Promo.List)
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
promoCodes.POST("", h.Admin.Promo.Create)
promoCodes.PUT("/:id", h.Admin.Promo.Update)
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings") adminSettings := admin.Group("/settings")
{ {
......
package routes package routes
import ( import (
"time"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// RegisterAuthRoutes 注册认证相关路由 // RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes( func RegisterAuthRoutes(
v1 *gin.RouterGroup, v1 *gin.RouterGroup,
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware, jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client,
) { ) {
// 创建速率限制器
rateLimiter := middleware.NewRateLimiter(redisClient)
// 公开接口 // 公开接口
auth := v1.Group("/auth") auth := v1.Group("/auth")
{ {
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
......
...@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro ...@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return fmt.Errorf("cannot set self as fallback group") return fmt.Errorf("cannot set self as fallback group")
} }
// 检查降级分组是否存在 visited := map[int64]struct{}{}
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID) nextID := fallbackGroupID
if err != nil { for {
return fmt.Errorf("fallback group not found: %w", err) if _, seen := visited[nextID]; seen {
} return fmt.Errorf("fallback group cycle detected")
}
visited[nextID] = struct{}{}
if currentGroupID > 0 && nextID == currentGroupID {
return fmt.Errorf("fallback group cycle detected")
}
// 降级分组不能启用 claude_code_only,否则会造成死循环 // 检查降级分组是否存在
if fallbackGroup.ClaudeCodeOnly { fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
return fmt.Errorf("fallback group cannot have claude_code_only enabled") if err != nil {
} return fmt.Errorf("fallback group not found: %w", err)
}
return nil // 降级分组不能启用 claude_code_only,否则会造成死循环
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
if fallbackGroup.FallbackGroupID == nil {
return nil
}
nextID = *fallbackGroup.FallbackGroupID
}
} }
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
......
...@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) { ...@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByIDLite call")
}
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error { func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
panic("unexpected Update call") panic("unexpected Update call")
} }
......
...@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err ...@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return s.getByID, nil return s.getByID, nil
} }
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error { func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call") panic("unexpected Delete call")
} }
...@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require.True(t, *repo.listWithFiltersIsExclusive) require.True(t, *repo.listWithFiltersIsExclusive)
}) })
} }
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
groupID := int64(1)
fallbackID := int64(2)
repo := &groupRepoStubForFallbackCycle{
groups: map[int64]*Group{
groupID: {
ID: groupID,
FallbackGroupID: &fallbackID,
},
fallbackID: {
ID: fallbackID,
FallbackGroupID: &groupID,
},
},
}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cycle")
}
type groupRepoStubForFallbackCycle struct {
groups map[int64]*Group
}
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
panic("unexpected Create call")
}
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
panic("unexpected Update call")
}
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
...@@ -35,9 +35,6 @@ var ( ...@@ -35,9 +35,6 @@ var (
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const maxTokenLength = 8192 const maxTokenLength = 8192
// LinuxDoConnectSyntheticEmailDomain LinuxDo Connect 生成的合成邮箱域名后缀
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo.synthetic"
// JWTClaims JWT载荷数据 // JWTClaims JWT载荷数据
type JWTClaims struct { type JWTClaims struct {
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
...@@ -55,6 +52,7 @@ type AuthService struct { ...@@ -55,6 +52,7 @@ type AuthService struct {
emailService *EmailService emailService *EmailService
turnstileService *TurnstileService turnstileService *TurnstileService
emailQueueService *EmailQueueService emailQueueService *EmailQueueService
promoService *PromoService
} }
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
...@@ -65,6 +63,7 @@ func NewAuthService( ...@@ -65,6 +63,7 @@ func NewAuthService(
emailService *EmailService, emailService *EmailService,
turnstileService *TurnstileService, turnstileService *TurnstileService,
emailQueueService *EmailQueueService, emailQueueService *EmailQueueService,
promoService *PromoService,
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
...@@ -73,16 +72,17 @@ func NewAuthService( ...@@ -73,16 +72,17 @@ func NewAuthService(
emailService: emailService, emailService: emailService,
turnstileService: turnstileService, turnstileService: turnstileService,
emailQueueService: emailQueueService, emailQueueService: emailQueueService,
promoService: promoService,
} }
} }
// Register 用户注册,返回token和用户 // Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "") return s.RegisterWithVerification(ctx, email, password, "", "")
} }
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
...@@ -153,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -153,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
} else {
// 重新获取用户信息以获取更新后的余额
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
user = updatedUser
}
}
}
// 生成token // 生成token
token, err := s.GenerateToken(user) token, err := s.GenerateToken(user)
if err != nil { if err != nil {
......
...@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E ...@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
emailService, emailService,
nil, nil,
nil, nil,
nil, // promoService
) )
} }
...@@ -131,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi ...@@ -131,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil) }, nil)
// 应返回服务不可用错误,而不是允许绕过验证 // 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
...@@ -143,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { ...@@ -143,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired) require.ErrorIs(t, err, ErrEmailVerifyRequired)
} }
...@@ -157,7 +158,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { ...@@ -157,7 +158,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode) require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code") require.ErrorContains(t, err, "verify code")
} }
......
...@@ -38,6 +38,12 @@ const ( ...@@ -38,6 +38,12 @@ const (
RedeemTypeSubscription = "subscription" RedeemTypeSubscription = "subscription"
) )
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
)
// Admin adjustment type constants // Admin adjustment type constants
const ( const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额 AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
...@@ -57,6 +63,9 @@ const ( ...@@ -57,6 +63,9 @@ const (
SubscriptionStatusSuspended = "suspended" SubscriptionStatusSuspended = "suspended"
) )
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys // Setting keys
const ( const (
// 注册设置 // 注册设置
...@@ -90,6 +99,7 @@ const ( ...@@ -90,6 +99,7 @@ const (
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入) SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment