Commit f6f072cb authored by Edric Li's avatar Edric Li
Browse files

Merge branch 'main' into feat/api-key-ip-restriction

parents 5265b12c ff087586
...@@ -4,6 +4,8 @@ package repository ...@@ -4,6 +4,8 @@ package repository
import ( import (
"context" "context"
"database/sql"
"errors"
"testing" "testing"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
...@@ -19,6 +21,20 @@ type GroupRepoSuite struct { ...@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo *groupRepository repo *groupRepository
} }
type forbidSQLExecutor struct {
called bool
}
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
s.called = true
return nil, errors.New("unexpected sql exec")
}
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
s.called = true
return nil, errors.New("unexpected sql query")
}
func (s *GroupRepoSuite) SetupTest() { func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
tx := testEntTx(s.T()) tx := testEntTx(s.T())
...@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() { ...@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s.Require().ErrorIs(err, service.ErrGroupNotFound) s.Require().ErrorIs(err, service.ErrGroupNotFound)
} }
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
group := &service.Group{
Name: "lite-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
spy := &forbidSQLExecutor{}
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
got, err := repo.GetByIDLite(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.ID, got.ID)
s.Require().False(spy.called, "expected no direct sql executor usage")
}
func (s *GroupRepoSuite) TestUpdate() { func (s *GroupRepoSuite) TestUpdate() {
group := &service.Group{ group := &service.Group{
Name: "original", Name: "original",
......
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type promoCodeRepository struct {
client *dbent.Client
}
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
return &promoCodeRepository{client: client}
}
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.Create().
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
code.ID = created.ID
code.CreatedAt = created.CreatedAt
code.UpdatedAt = created.UpdatedAt
return nil
}
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.IDEQ(id)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
client := clientFromContext(ctx, r.client)
m, err := client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
ForUpdate().
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
updated, err := builder.Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrPromoCodeNotFound
}
return err
}
code.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
return err
}
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "")
}
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
q := r.client.PromoCode.Query()
if status != "" {
q = q.Where(promocode.StatusEQ(status))
}
if search != "" {
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
codes, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocode.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outCodes := promoCodeEntitiesToService(codes)
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create().
SetPromoCodeID(usage.PromoCodeID).
SetUserID(usage.UserID).
SetBonusAmount(usage.BonusAmount).
SetUsedAt(usage.UsedAt).
Save(ctx)
if err != nil {
return err
}
usage.ID = created.ID
return nil
}
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
m, err := r.client.PromoCodeUsage.Query().
Where(
promocodeusage.PromoCodeIDEQ(promoCodeID),
promocodeusage.UserIDEQ(userID),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return promoCodeUsageEntityToService(m), nil
}
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
usages, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocodeusage.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outUsages := promoCodeUsageEntitiesToService(usages)
return outUsages, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.UpdateOneID(id).
AddUsedCount(1).
Save(ctx)
return err
}
// Entity to Service conversions
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
if m == nil {
return nil
}
return &service.PromoCode{
ID: m.ID,
Code: m.Code,
BonusAmount: m.BonusAmount,
MaxUses: m.MaxUses,
UsedCount: m.UsedCount,
Status: m.Status,
ExpiresAt: m.ExpiresAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
out := make([]service.PromoCode, 0, len(models))
for i := range models {
if s := promoCodeEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
if m == nil {
return nil
}
out := &service.PromoCodeUsage{
ID: m.ID,
PromoCodeID: m.PromoCodeID,
UserID: m.UserID,
BonusAmount: m.BonusAmount,
UsedAt: m.UsedAt,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
return out
}
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
out := make([]service.PromoCodeUsage, 0, len(models))
for i := range models {
if s := promoCodeUsageEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
...@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet( ...@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
NewAccountRepository, NewAccountRepository,
NewProxyRepository, NewProxyRepository,
NewRedeemCodeRepository, NewRedeemCodeRepository,
NewPromoCodeRepository,
NewUsageLogRepository, NewUsageLogRepository,
NewSettingRepository, NewSettingRepository,
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
......
...@@ -398,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -398,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
...@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err ...@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return nil, service.ErrGroupNotFound return nil, service.ErrGroupNotFound
} }
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error { func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9"
) )
// ProviderSet 提供服务器层的依赖 // ProviderSet 提供服务器层的依赖
...@@ -31,6 +32,7 @@ func ProvideRouter( ...@@ -31,6 +32,7 @@ func ProvideRouter(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
settingService *service.SettingService, settingService *service.SettingService,
redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
...@@ -48,7 +50,7 @@ func ProvideRouter( ...@@ -48,7 +50,7 @@ func ProvideRouter(
} }
} }
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, settingService, cfg) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, settingService, cfg, redisClient)
} }
// ProvideHTTPServer 提供 HTTP 服务器 // ProvideHTTPServer 提供 HTTP 服务器
......
package middleware package middleware
import ( import (
"context"
"errors" "errors"
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
return return
} }
...@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
} }
...@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool ...@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription, ok := value.(*service.UserSubscription) subscription, ok := value.(*service.UserSubscription)
return subscription, ok return subscription, ok
} }
func setGroupContext(c *gin.Context, group *service.Group) {
if !service.IsGroupContextValid(group) {
return
}
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
return
}
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
c.Request = c.Request.WithContext(ctx)
}
...@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
return return
} }
...@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
}) })
c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next() c.Next()
} }
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) { ...@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status) require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 99,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformGemini,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyService := service.NewAPIKeyService(
fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
},
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)
cfg := &config.Config{RunMode: config.RunModeSimple}
r := gin.New()
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID: 42, ID: 42,
Name: "sub", Name: "sub",
Status: service.StatusActive, Status: service.StatusActive,
Hydrated: true,
SubscriptionType: service.SubscriptionTypeSubscription, SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit, DailyLimitUSD: &limit,
} }
...@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}) })
} }
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
invalidGroup := &service.Group{
ID: group.ID,
Platform: group.Platform,
Status: group.Status,
}
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web" "github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// SetupRouter 配置路由器中间件和路由 // SetupRouter 配置路由器中间件和路由
...@@ -24,6 +25,7 @@ func SetupRouter( ...@@ -24,6 +25,7 @@ func SetupRouter(
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
settingService *service.SettingService, settingService *service.SettingService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware2.Logger()) r.Use(middleware2.Logger())
...@@ -44,7 +46,7 @@ func SetupRouter( ...@@ -44,7 +46,7 @@ func SetupRouter(
} }
// 注册路由 // 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
return r return r
} }
...@@ -59,6 +61,7 @@ func registerRoutes( ...@@ -59,6 +61,7 @@ func registerRoutes(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client,
) { ) {
// 通用路由(健康检查、状态等) // 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r) routes.RegisterCommonRoutes(r)
...@@ -67,7 +70,7 @@ func registerRoutes( ...@@ -67,7 +70,7 @@ func registerRoutes(
v1 := r.Group("/api/v1") v1 := r.Group("/api/v1")
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
......
...@@ -44,6 +44,9 @@ func RegisterAdminRoutes( ...@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
// 卡密管理 // 卡密管理
registerRedeemCodeRoutes(admin, h) registerRedeemCodeRoutes(admin, h)
// 优惠码管理
registerPromoCodeRoutes(admin, h)
// 系统设置 // 系统设置
registerSettingsRoutes(admin, h) registerSettingsRoutes(admin, h)
...@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
promoCodes := admin.Group("/promo-codes")
{
promoCodes.GET("", h.Admin.Promo.List)
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
promoCodes.POST("", h.Admin.Promo.Create)
promoCodes.PUT("/:id", h.Admin.Promo.Update)
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings") adminSettings := admin.Group("/settings")
{ {
......
package routes package routes
import ( import (
"time"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// RegisterAuthRoutes 注册认证相关路由 // RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes( func RegisterAuthRoutes(
v1 *gin.RouterGroup, v1 *gin.RouterGroup,
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware, jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client,
) { ) {
// 创建速率限制器
rateLimiter := middleware.NewRateLimiter(redisClient)
// 公开接口 // 公开接口
auth := v1.Group("/auth") auth := v1.Group("/auth")
{ {
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
......
...@@ -49,10 +49,12 @@ type AccountRepository interface { ...@@ -49,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
......
...@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt ...@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call") panic("unexpected SetRateLimited call")
} }
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call") panic("unexpected SetOverloaded call")
} }
...@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error { ...@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call") panic("unexpected ClearRateLimit call")
} }
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
panic("unexpected ClearAntigravityQuotaScopes call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call") panic("unexpected UpdateSessionWindow call")
} }
......
...@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro ...@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return fmt.Errorf("cannot set self as fallback group") return fmt.Errorf("cannot set self as fallback group")
} }
// 检查降级分组是否存在 visited := map[int64]struct{}{}
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID) nextID := fallbackGroupID
if err != nil { for {
return fmt.Errorf("fallback group not found: %w", err) if _, seen := visited[nextID]; seen {
} return fmt.Errorf("fallback group cycle detected")
}
visited[nextID] = struct{}{}
if currentGroupID > 0 && nextID == currentGroupID {
return fmt.Errorf("fallback group cycle detected")
}
// 降级分组不能启用 claude_code_only,否则会造成死循环 // 检查降级分组是否存在
if fallbackGroup.ClaudeCodeOnly { fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
return fmt.Errorf("fallback group cannot have claude_code_only enabled") if err != nil {
} return fmt.Errorf("fallback group not found: %w", err)
}
return nil // 降级分组不能启用 claude_code_only,否则会造成死循环
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
if fallbackGroup.FallbackGroupID == nil {
return nil
}
nextID = *fallbackGroup.FallbackGroupID
}
} }
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
......
...@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) { ...@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByIDLite call")
}
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error { func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
panic("unexpected Update call") panic("unexpected Update call")
} }
......
...@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err ...@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return s.getByID, nil return s.getByID, nil
} }
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error { func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call") panic("unexpected Delete call")
} }
...@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require.True(t, *repo.listWithFiltersIsExclusive) require.True(t, *repo.listWithFiltersIsExclusive)
}) })
} }
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
groupID := int64(1)
fallbackID := int64(2)
repo := &groupRepoStubForFallbackCycle{
groups: map[int64]*Group{
groupID: {
ID: groupID,
FallbackGroupID: &fallbackID,
},
fallbackID: {
ID: fallbackID,
FallbackGroupID: &groupID,
},
},
}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cycle")
}
type groupRepoStubForFallbackCycle struct {
groups map[int64]*Group
}
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
panic("unexpected Create call")
}
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
panic("unexpected Update call")
}
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
...@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct { ...@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct {
// 长前缀优先 // 长前缀优先
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
...@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel := claudeReq.Model originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model) mappedModel := s.getMappedModel(account, claudeReq.Model)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
...@@ -603,7 +605,7 @@ urlFallbackLoop: ...@@ -603,7 +605,7 @@ urlFallbackLoop:
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
if resp.StatusCode == 429 { if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
} }
// 最后一次尝试也失败 // 最后一次尝试也失败
resp = &http.Response{ resp = &http.Response{
...@@ -696,7 +698,7 @@ urlFallbackLoop: ...@@ -696,7 +698,7 @@ urlFallbackLoop:
// 处理错误响应(重试后仍失败或不触发重试) // 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
...@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if len(body) == 0 { if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
} }
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 解析请求以获取 image_size(用于图片计费) // 解析请求以获取 image_size(用于图片计费)
imageSize := s.extractImageSize(body) imageSize := s.extractImageSize(body)
...@@ -1146,7 +1149,7 @@ urlFallbackLoop: ...@@ -1146,7 +1149,7 @@ urlFallbackLoop:
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
if resp.StatusCode == 429 { if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
} }
resp = &http.Response{ resp = &http.Response{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
...@@ -1200,7 +1203,7 @@ urlFallbackLoop: ...@@ -1200,7 +1203,7 @@ urlFallbackLoop:
goto handleSuccess goto handleSuccess
} }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
...@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { ...@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
} }
} }
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) { func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间) // 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 { if statusCode == 429 {
resetAt := ParseGeminiRateLimitResetTime(body) resetAt := ParseGeminiRateLimitResetTime(body)
...@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre ...@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
defaultDur = 5 * time.Minute defaultDur = 5 * time.Minute
} }
ra := time.Now().Add(defaultDur) ra := time.Now().Add(defaultDur)
log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur) log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
return return
} }
resetTime := time.Unix(*resetAt, 0) resetTime := time.Unix(*resetAt, 0)
log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
return return
} }
// 其他错误码继续使用 rateLimitService // 其他错误码继续使用 rateLimitService
......
package service
import (
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}
...@@ -52,6 +52,7 @@ type AuthService struct { ...@@ -52,6 +52,7 @@ type AuthService struct {
emailService *EmailService emailService *EmailService
turnstileService *TurnstileService turnstileService *TurnstileService
emailQueueService *EmailQueueService emailQueueService *EmailQueueService
promoService *PromoService
} }
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
...@@ -62,6 +63,7 @@ func NewAuthService( ...@@ -62,6 +63,7 @@ func NewAuthService(
emailService *EmailService, emailService *EmailService,
turnstileService *TurnstileService, turnstileService *TurnstileService,
emailQueueService *EmailQueueService, emailQueueService *EmailQueueService,
promoService *PromoService,
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
...@@ -70,16 +72,17 @@ func NewAuthService( ...@@ -70,16 +72,17 @@ func NewAuthService(
emailService: emailService, emailService: emailService,
turnstileService: turnstileService, turnstileService: turnstileService,
emailQueueService: emailQueueService, emailQueueService: emailQueueService,
promoService: promoService,
} }
} }
// Register 用户注册,返回token和用户 // Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "") return s.RegisterWithVerification(ctx, email, password, "", "")
} }
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
...@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
} else {
// 重新获取用户信息以获取更新后的余额
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
user = updatedUser
}
}
}
// 生成token // 生成token
token, err := s.GenerateToken(user) token, err := s.GenerateToken(user)
if err != nil { if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment