"git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "a14babdc732101397db1cd3bb36d91fa93c305be"
Commit 65e69738 authored by cyhhao's avatar cyhhao
Browse files

Merge branch 'main' of github.com:Wei-Shaw/sub2api

parents c8e2f614 39fad63c
...@@ -136,11 +136,24 @@ var allowedHeaders = map[string]bool{ ...@@ -136,11 +136,24 @@ var allowedHeaders = map[string]bool{
"content-type": true, "content-type": true,
} }
// GatewayCache defines cache operations for gateway service // GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface { type GatewayCache interface {
// GetSessionAccountID 获取粘性会话绑定的账号 ID
// Get the account ID bound to a sticky session
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
// SetSessionAccountID 设置粘性会话与账号的绑定关系
// Set the binding between sticky session and account
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
// RefreshSessionTTL 刷新粘性会话的过期时间
// Refresh the expiration time of a sticky session
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
} }
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil // derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...@@ -151,6 +164,28 @@ func derefGroupID(groupID *int64) int64 { ...@@ -151,6 +164,28 @@ func derefGroupID(groupID *int64) int64 {
return *groupID return *groupID
} }
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
if account == nil {
return false
}
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
return true
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
return false
}
type AccountWaitPlan struct { type AccountWaitPlan struct {
AccountID int64 AccountID int64
MaxConcurrency int MaxConcurrency int
...@@ -1067,6 +1102,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1067,6 +1102,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择 // 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} }
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
} }
} }
...@@ -1173,41 +1210,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1173,41 +1210,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) { if err == nil && accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID] account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) && if ok {
s.isAccountAllowedForPlatform(account, platform, useMixed) && // 检查账户是否需要清理粘性会话绑定
account.IsSchedulableForModel(requestedModel) && // Check if the account needs sticky session cleanup
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) && clearSticky := shouldClearStickySession(account)
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 if clearSticky {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
} }
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额) // 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, account, sessionHash) { // Session count limit check (wait plan also requires session quota)
// 会话限制已满,继续到 Layer 2 if !s.checkAndRegisterSession(ctx, account, sessionHash) {
} else { // 会话限制已满,继续到 Layer 2
return &AccountSelectionResult{ // Session limit full, continue to Layer 2
Account: account, } else {
WaitPlan: &AccountWaitPlan{ return &AccountSelectionResult{
AccountID: accountID, Account: account,
MaxConcurrency: account.Concurrency, WaitPlan: &AccountWaitPlan{
Timeout: cfg.StickySessionWaitTimeout, AccountID: accountID,
MaxWaiting: cfg.StickySessionMaxWaiting, MaxConcurrency: account.Concurrency,
}, Timeout: cfg.StickySessionWaitTimeout,
}, nil MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
} }
} }
} }
...@@ -1827,14 +1875,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1827,14 +1875,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { clearSticky := shouldClearStickySession(account)
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if s.debugModelRoutingEnabled() { if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
} }
return account, nil
} }
} }
} }
...@@ -1924,11 +1978,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1924,11 +1978,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { clearSticky := shouldClearStickySession(account)
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
} }
return account, nil
} }
} }
} }
...@@ -2028,15 +2088,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2028,15 +2088,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { clearSticky := shouldClearStickySession(account)
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if clearSticky {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if s.debugModelRoutingEnabled() { if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
} }
return account, nil
} }
} }
} }
...@@ -2127,12 +2193,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2127,12 +2193,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { clearSticky := shouldClearStickySession(account)
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if clearSticky {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
} }
return account, nil
} }
} }
} }
......
...@@ -82,145 +82,276 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, ...@@ -82,145 +82,276 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
} }
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. 确定目标平台和调度模式
// Determine target platform and scheduling mode
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
if err != nil {
return nil, err
}
cacheKey := "gemini:" + sessionHash
// 2. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
return account, nil
}
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
}
// 4. 按优先级 + LRU 选择最佳账号
// Select best account by priority + LRU
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available Gemini accounts")
}
// 5. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
return selected, nil
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
//
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
// 优先检查 context 中的强制平台(/antigravity 路由) // 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" { if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform return forcePlatform, false, true, nil
} else if groupID != nil { }
if groupID != nil {
// 根据分组 platform 决定查询哪种账号 // 根据分组 platform 决定查询哪种账号
var group *Group var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID { if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup group = ctxGroup
} else { } else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID) group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group failed: %w", err) return "", false, false, fmt.Errorf("get group failed: %w", err)
} }
} }
platform = group.Platform // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
} else { return group.Platform, group.Platform == PlatformGemini, false, nil
// 无分组时只使用原生 gemini 平台
platform = PlatformGemini
} }
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 无分组时只使用原生 gemini 平台
// 注意:强制平台模式不走混合调度 return PlatformGemini, true, false, nil
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform }
cacheKey := "gemini:" + sessionHash // tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account unavailable.
func (s *GeminiMessagesCompatService) tryStickySessionHit(
ctx context.Context,
groupID *int64,
sessionHash, cacheKey, requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
if sessionHash == "" {
return nil
}
if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) if err != nil || accountID <= 0 {
if err == nil && accountID > 0 { return nil
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
valid = true
}
if valid {
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
}
} }
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部) if _, excluded := excludedIDs[accountID]; excluded {
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform) return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil
} }
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && groupID != nil && hasForcePlatform { // 检查账号是否需要清理粘性会话
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform) // Check if sticky session should be cleared
if err != nil { if shouldClearStickySession(account) {
return nil, fmt.Errorf("query accounts failed: %w", err) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
} return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
return nil
} }
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account
}
// isAccountUsableForRequest 检查账号是否可用于当前请求。
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
//
// isAccountUsableForRequest checks if account is usable for current request.
// Validates: model scheduling, model support, platform matching, rate limit precheck.
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
ctx context.Context,
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) {
return false
}
// 检查模型支持
// Check model support
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
return false
}
// 检查平台匹配
// Check platform matching
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
return false
}
// 速率限制预检
// Rate limit precheck
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
return false
}
return true
}
// isAccountValidForPlatform 检查账号是否匹配目标平台。
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
//
// isAccountValidForPlatform checks if account matches target platform.
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
if account.Platform == platform {
return true
}
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
return true
}
return false
}
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
return ok
}
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
// 返回 nil 表示无可用账号。
//
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
// Returns nil if no available account.
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
ctx context.Context,
accounts []Account,
requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
// 跳过被排除的账号
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
// 非混合调度模式(antigravity 分组):不需要过滤 // 检查账号是否可用于当前请求
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel) // 选择最佳账号
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil { if selected == nil {
selected = acc selected = acc
continue continue
} }
if acc.Priority < selected.Priority {
if s.isBetterGeminiAccount(acc, selected) {
selected = acc selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
} }
} }
if selected == nil { return selected
if requestedModel != "" { }
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available Gemini accounts")
}
if sessionHash != "" { // isBetterGeminiAccount 判断 candidate 是否比 current 更优。
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
//
// isBetterGeminiAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
if candidate.Priority < current.Priority {
return true
}
if candidate.Priority > current.Priority {
return false
} }
return selected, nil // 同优先级,比较最后使用时间
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
} }
// isModelSupportedByAccount 根据账户平台检查模型支持 // isModelSupportedByAccount 根据账户平台检查模型支持
......
...@@ -15,8 +15,10 @@ import ( ...@@ -15,8 +15,10 @@ import (
// mockAccountRepoForGemini Gemini 测试用的 mock // mockAccountRepoForGemini Gemini 测试用的 mock
type mockAccountRepoForGemini struct { type mockAccountRepoForGemini struct {
accounts []Account accounts []Account
accountsByID map[int64]*Account accountsByID map[int64]*Account
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
} }
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) { func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
...@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, ...@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
return nil, nil return nil, nil
} }
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
if m.listByPlatformFunc != nil {
return m.listByPlatformFunc(ctx, platforms)
}
var result []Account var result []Account
platformSet := make(map[string]bool) platformSet := make(map[string]bool)
for _, p := range platforms { for _, p := range platforms {
...@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex ...@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
return result, nil return result, nil
} }
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
if m.listByGroupFunc != nil {
return m.listByGroupFunc(ctx, groupID, platforms)
}
return m.ListSchedulableByPlatforms(ctx, platforms) return m.ListSchedulableByPlatforms(ctx, platforms)
} }
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
...@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil) ...@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock // mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct { type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64 sessionBindings map[string]int64
deletedSessions map[string]int
} }
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
...@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group ...@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
return nil return nil
} }
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if m.sessionBindings == nil {
return nil
}
if m.deletedSessions == nil {
m.deletedSessions = make(map[string]int)
}
m.deletedSessions[sessionHash]++
delete(m.sessionBindings, sessionHash)
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background() ctx := context.Background()
...@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS ...@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
// 粘性会话未命中,按优先级选择 // 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择") require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
}) })
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
})
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
ctx := context.Background()
groupID := int64(9)
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
repo := &mockAccountRepoForGemini{
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return nil, nil
},
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
}, nil
},
accountsByID: map[int64]*Account{
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{
ID: 1,
Platform: PlatformGemini,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-999": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
excluded := map[int64]struct{}{1: {}}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return nil, errors.New("query failed")
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "query accounts failed")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
ctx := context.Background()
oldTime := time.Now().Add(-2 * time.Hour)
newTime := time.Now().Add(-1 * time.Hour)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
} }
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑 // TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
......
...@@ -180,81 +180,164 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI ...@@ -180,81 +180,164 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
} }
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session cacheKey := "openai:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) // 1. 尝试粘性会话命中
if err == nil && accountID > 0 { // Try sticky session hit
if _, excluded := excludedIDs[accountID]; !excluded { if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
account, err := s.getSchedulableAccount(ctx, accountID) return account, nil
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
}
}
}
} }
// 2. Get schedulable OpenAI accounts // 2. 获取可调度的 OpenAI 账号
// Get schedulable OpenAI accounts
accounts, err := s.listSchedulableAccounts(ctx, groupID) accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
} }
// 3. Select by priority + LRU // 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available OpenAI accounts")
}
// 4. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
}
return selected, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
if sessionHash == "" {
return nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err != nil || accountID <= 0 {
return nil
}
if _, excluded := excludedIDs[accountID]; excluded {
return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil {
return nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !account.IsSchedulable() || !account.IsOpenAI() {
return nil
}
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
return account
}
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
// 返回 nil 表示无可用账号。
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
// 跳过被排除的账号
// Skip excluded accounts
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded. // 调度器快照可能暂时过时,这里重新检查可调度性和平台
if !acc.IsSchedulable() { // Scheduler snapshots can be temporarily stale; re-check schedulability and platform
if !acc.IsSchedulable() || !acc.IsOpenAI() {
continue continue
} }
// 检查模型支持
// Check model support // Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
} }
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
if selected == nil { if selected == nil {
selected = acc selected = acc
continue continue
} }
// Lower priority value means higher priority
if acc.Priority < selected.Priority { if s.isBetterAccount(acc, selected) {
selected = acc selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
default:
// Same priority, select least recently used
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
} }
} }
if selected == nil { return selected
if requestedModel != "" { }
return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available OpenAI accounts")
}
// 4. Set sticky session // isBetterAccount 判断 candidate 是否比 current 更优。
if sessionHash != "" { // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL) //
// isBetterAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
// Higher priority (lower value)
if candidate.Priority < current.Priority {
return true
}
if candidate.Priority > current.Priority {
return false
} }
return selected, nil // 同优先级,比较最后使用时间
// Same priority, compare last used time
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,保持
return false
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
} }
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
...@@ -325,29 +408,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -325,29 +408,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) { if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && if err == nil {
(requestedModel == "" || account.IsModelSupported(requestedModel)) { clearSticky := shouldClearStickySession(account)
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if clearSticky {
if err == nil && result.Acquired { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
AccountID: accountID, AccountID: accountID,
MaxConcurrency: account.Concurrency, MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout, Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting, MaxWaiting: cfg.StickySessionMaxWaiting,
}, },
}, nil }, nil
}
} }
} }
} }
......
...@@ -60,6 +60,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -60,6 +60,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{ keys := []string{
SettingKeyRegistrationEnabled, SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled, SettingKeyEmailVerifyEnabled,
SettingKeyPromoCodeEnabled,
SettingKeyTurnstileEnabled, SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey, SettingKeyTurnstileSiteKey,
SettingKeySiteName, SettingKeySiteName,
...@@ -88,6 +89,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -88,6 +89,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{ return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
...@@ -125,6 +127,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -125,6 +127,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
return &struct { return &struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
...@@ -140,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -140,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName, SiteName: settings.SiteName,
...@@ -162,6 +166,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -162,6 +166,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 注册设置 // 注册设置
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
// 邮件服务设置(只有非空才更新密码) // 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost updates[SettingKeySMTPHost] = settings.SMTPHost
...@@ -248,6 +253,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { ...@@ -248,6 +253,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
return value == "true" return value == "true"
} }
// IsPromoCodeEnabled 检查是否启用优惠码功能
func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
if err != nil {
return true // 默认启用
}
return value != "false"
}
// GetSiteName 获取网站名称 // GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string { func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
...@@ -297,6 +311,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -297,6 +311,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
defaults := map[string]string{ defaults := map[string]string{
SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false", SettingKeyEmailVerifyEnabled: "false",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API", SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "", SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
...@@ -328,6 +343,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -328,6 +343,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{ result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
SMTPHost: settings[SettingKeySMTPHost], SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername], SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom], SMTPFrom: settings[SettingKeySMTPFrom],
......
...@@ -3,6 +3,7 @@ package service ...@@ -3,6 +3,7 @@ package service
type SystemSettings struct { type SystemSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
PromoCodeEnabled bool
SMTPHost string SMTPHost string
SMTPPort int SMTPPort int
...@@ -58,6 +59,7 @@ type SystemSettings struct { ...@@ -58,6 +59,7 @@ type SystemSettings struct {
type PublicSettings struct { type PublicSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
PromoCodeEnabled bool
TurnstileEnabled bool TurnstileEnabled bool
TurnstileSiteKey string TurnstileSiteKey string
SiteName string SiteName string
......
//go:build unit
// Package service 提供 API 网关核心服务。
// 本文件包含 shouldClearStickySession 函数的单元测试,
// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
//
// This file contains unit tests for the shouldClearStickySession function,
// verifying correct sticky session clearing behavior under various account states.
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
// 验证在以下情况下是否正确判断需要清理粘性会话:
// - nil 账号:不清理(返回 false)
// - 状态为错误或禁用:清理
// - 不可调度:清理
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable.
func TestShouldClearStickySession(t *testing.T) {
now := time.Now()
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
name string
account *Account
want bool
}{
{name: "nil account", account: nil, want: false},
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true},
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true},
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true},
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, shouldClearStickySession(tt.account))
})
}
}
...@@ -345,6 +345,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) { ...@@ -345,6 +345,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
repo.mu.Lock() repo.mu.Lock()
defer repo.mu.Unlock() defer repo.mu.Unlock()
require.Len(t, repo.deleteCalls, 3) require.Len(t, repo.deleteCalls, 3)
require.Equal(t, 2, repo.deleteCalls[0].limit)
require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
require.Len(t, repo.markSucceeded, 1) require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed) require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID) require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
......
...@@ -12,6 +12,7 @@ export interface SystemSettings { ...@@ -12,6 +12,7 @@ export interface SystemSettings {
// Registration settings // Registration settings
registration_enabled: boolean registration_enabled: boolean
email_verify_enabled: boolean email_verify_enabled: boolean
promo_code_enabled: boolean
// Default settings // Default settings
default_balance: number default_balance: number
default_concurrency: number default_concurrency: number
...@@ -64,6 +65,7 @@ export interface SystemSettings { ...@@ -64,6 +65,7 @@ export interface SystemSettings {
export interface UpdateSettingsRequest { export interface UpdateSettingsRequest {
registration_enabled?: boolean registration_enabled?: boolean
email_verify_enabled?: boolean email_verify_enabled?: boolean
promo_code_enabled?: boolean
default_balance?: number default_balance?: number
default_concurrency?: number default_concurrency?: number
site_name?: string site_name?: string
......
...@@ -2726,7 +2726,9 @@ export default { ...@@ -2726,7 +2726,9 @@ export default {
enableRegistration: 'Enable Registration', enableRegistration: 'Enable Registration',
enableRegistrationHint: 'Allow new users to register', enableRegistrationHint: 'Allow new users to register',
emailVerification: 'Email Verification', emailVerification: 'Email Verification',
emailVerificationHint: 'Require email verification for new registrations' emailVerificationHint: 'Require email verification for new registrations',
promoCode: 'Promo Code',
promoCodeHint: 'Allow users to use promo codes during registration'
}, },
turnstile: { turnstile: {
title: 'Cloudflare Turnstile', title: 'Cloudflare Turnstile',
......
...@@ -2879,7 +2879,9 @@ export default { ...@@ -2879,7 +2879,9 @@ export default {
enableRegistration: '开放注册', enableRegistration: '开放注册',
enableRegistrationHint: '允许新用户注册', enableRegistrationHint: '允许新用户注册',
emailVerification: '邮箱验证', emailVerification: '邮箱验证',
emailVerificationHint: '新用户注册时需要验证邮箱' emailVerificationHint: '新用户注册时需要验证邮箱',
promoCode: '优惠码',
promoCodeHint: '允许用户在注册时使用优惠码'
}, },
turnstile: { turnstile: {
title: 'Cloudflare Turnstile', title: 'Cloudflare Turnstile',
......
...@@ -312,6 +312,7 @@ export const useAppStore = defineStore('app', () => { ...@@ -312,6 +312,7 @@ export const useAppStore = defineStore('app', () => {
return { return {
registration_enabled: false, registration_enabled: false,
email_verify_enabled: false, email_verify_enabled: false,
promo_code_enabled: true,
turnstile_enabled: false, turnstile_enabled: false,
turnstile_site_key: '', turnstile_site_key: '',
site_name: siteName.value, site_name: siteName.value,
......
...@@ -70,6 +70,7 @@ export interface SendVerifyCodeResponse { ...@@ -70,6 +70,7 @@ export interface SendVerifyCodeResponse {
export interface PublicSettings { export interface PublicSettings {
registration_enabled: boolean registration_enabled: boolean
email_verify_enabled: boolean email_verify_enabled: boolean
promo_code_enabled: boolean
turnstile_enabled: boolean turnstile_enabled: boolean
turnstile_site_key: string turnstile_site_key: string
site_name: string site_name: string
......
...@@ -238,7 +238,30 @@ ...@@ -238,7 +238,30 @@
v-model="generateForm.group_id" v-model="generateForm.group_id"
:options="subscriptionGroupOptions" :options="subscriptionGroupOptions"
:placeholder="t('admin.redeem.selectGroupPlaceholder')" :placeholder="t('admin.redeem.selectGroupPlaceholder')"
/> >
<template #selected="{ option }">
<GroupBadge
v-if="option"
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
/>
<span v-else class="text-gray-400">{{
t('admin.redeem.selectGroupPlaceholder')
}}</span>
</template>
<template #option="{ option, selected }">
<GroupOptionItem
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
:description="(option as unknown as GroupOption).description"
:selected="selected"
/>
</template>
</Select>
</div> </div>
<div> <div>
<label class="input-label">{{ t('admin.redeem.validityDays') }}</label> <label class="input-label">{{ t('admin.redeem.validityDays') }}</label>
...@@ -370,7 +393,7 @@ import { useAppStore } from '@/stores/app' ...@@ -370,7 +393,7 @@ import { useAppStore } from '@/stores/app'
import { useClipboard } from '@/composables/useClipboard' import { useClipboard } from '@/composables/useClipboard'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import { formatDateTime } from '@/utils/format' import { formatDateTime } from '@/utils/format'
import type { RedeemCode, RedeemCodeType, Group } from '@/types' import type { RedeemCode, RedeemCodeType, Group, GroupPlatform, SubscriptionType } from '@/types'
import type { Column } from '@/components/common/types' import type { Column } from '@/components/common/types'
import AppLayout from '@/components/layout/AppLayout.vue' import AppLayout from '@/components/layout/AppLayout.vue'
import TablePageLayout from '@/components/layout/TablePageLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue'
...@@ -378,12 +401,23 @@ import DataTable from '@/components/common/DataTable.vue' ...@@ -378,12 +401,23 @@ import DataTable from '@/components/common/DataTable.vue'
import Pagination from '@/components/common/Pagination.vue' import Pagination from '@/components/common/Pagination.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue' import Select from '@/components/common/Select.vue'
import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n() const { t } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
const { copyToClipboard: clipboardCopy } = useClipboard() const { copyToClipboard: clipboardCopy } = useClipboard()
interface GroupOption {
value: number
label: string
description: string | null
platform: GroupPlatform
subscriptionType: SubscriptionType
rate: number
}
const showGenerateDialog = ref(false) const showGenerateDialog = ref(false)
const showResultDialog = ref(false) const showResultDialog = ref(false)
const generatedCodes = ref<RedeemCode[]>([]) const generatedCodes = ref<RedeemCode[]>([])
...@@ -395,7 +429,11 @@ const subscriptionGroupOptions = computed(() => { ...@@ -395,7 +429,11 @@ const subscriptionGroupOptions = computed(() => {
.filter((g) => g.subscription_type === 'subscription') .filter((g) => g.subscription_type === 'subscription')
.map((g) => ({ .map((g) => ({
value: g.id, value: g.id,
label: g.name label: g.name,
description: g.description,
platform: g.platform,
subscriptionType: g.subscription_type,
rate: g.rate_multiplier
})) }))
}) })
......
...@@ -323,6 +323,21 @@ ...@@ -323,6 +323,21 @@
</div> </div>
<Toggle v-model="form.email_verify_enabled" /> <Toggle v-model="form.email_verify_enabled" />
</div> </div>
<!-- Promo Code -->
<div
class="flex items-center justify-between border-t border-gray-100 pt-4 dark:border-dark-700"
>
<div>
<label class="font-medium text-gray-900 dark:text-white">{{
t('admin.settings.registration.promoCode')
}}</label>
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.settings.registration.promoCodeHint') }}
</p>
</div>
<Toggle v-model="form.promo_code_enabled" />
</div>
</div> </div>
</div> </div>
...@@ -1013,6 +1028,7 @@ type SettingsForm = SystemSettings & { ...@@ -1013,6 +1028,7 @@ type SettingsForm = SystemSettings & {
const form = reactive<SettingsForm>({ const form = reactive<SettingsForm>({
registration_enabled: true, registration_enabled: true,
email_verify_enabled: false, email_verify_enabled: false,
promo_code_enabled: true,
default_balance: 0, default_balance: 0,
default_concurrency: 1, default_concurrency: 1,
site_name: 'Sub2API', site_name: 'Sub2API',
...@@ -1135,6 +1151,7 @@ async function saveSettings() { ...@@ -1135,6 +1151,7 @@ async function saveSettings() {
const payload: UpdateSettingsRequest = { const payload: UpdateSettingsRequest = {
registration_enabled: form.registration_enabled, registration_enabled: form.registration_enabled,
email_verify_enabled: form.email_verify_enabled, email_verify_enabled: form.email_verify_enabled,
promo_code_enabled: form.promo_code_enabled,
default_balance: form.default_balance, default_balance: form.default_balance,
default_concurrency: form.default_concurrency, default_concurrency: form.default_concurrency,
site_name: form.site_name, site_name: form.site_name,
......
...@@ -466,7 +466,28 @@ ...@@ -466,7 +466,28 @@
v-model="assignForm.group_id" v-model="assignForm.group_id"
:options="subscriptionGroupOptions" :options="subscriptionGroupOptions"
:placeholder="t('admin.subscriptions.selectGroup')" :placeholder="t('admin.subscriptions.selectGroup')"
/> >
<template #selected="{ option }">
<GroupBadge
v-if="option"
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
/>
<span v-else class="text-gray-400">{{ t('admin.subscriptions.selectGroup') }}</span>
</template>
<template #option="{ option, selected }">
<GroupOptionItem
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
:description="(option as unknown as GroupOption).description"
:selected="selected"
/>
</template>
</Select>
<p class="input-hint">{{ t('admin.subscriptions.groupHint') }}</p> <p class="input-hint">{{ t('admin.subscriptions.groupHint') }}</p>
</div> </div>
<div> <div>
...@@ -599,7 +620,7 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue' ...@@ -599,7 +620,7 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import type { UserSubscription, Group } from '@/types' import type { UserSubscription, Group, GroupPlatform, SubscriptionType } from '@/types'
import type { SimpleUser } from '@/api/admin/usage' import type { SimpleUser } from '@/api/admin/usage'
import type { Column } from '@/components/common/types' import type { Column } from '@/components/common/types'
import { formatDateOnly } from '@/utils/format' import { formatDateOnly } from '@/utils/format'
...@@ -612,11 +633,21 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue' ...@@ -612,11 +633,21 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import EmptyState from '@/components/common/EmptyState.vue' import EmptyState from '@/components/common/EmptyState.vue'
import Select from '@/components/common/Select.vue' import Select from '@/components/common/Select.vue'
import GroupBadge from '@/components/common/GroupBadge.vue' import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n() const { t } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
interface GroupOption {
value: number
label: string
description: string | null
platform: GroupPlatform
subscriptionType: SubscriptionType
rate: number
}
// User column display mode: 'email' or 'username' // User column display mode: 'email' or 'username'
const userColumnMode = ref<'email' | 'username'>('email') const userColumnMode = ref<'email' | 'username'>('email')
const USER_COLUMN_MODE_KEY = 'subscription-user-column-mode' const USER_COLUMN_MODE_KEY = 'subscription-user-column-mode'
...@@ -792,7 +823,14 @@ const groupOptions = computed(() => [ ...@@ -792,7 +823,14 @@ const groupOptions = computed(() => [
const subscriptionGroupOptions = computed(() => const subscriptionGroupOptions = computed(() =>
groups.value groups.value
.filter((g) => g.subscription_type === 'subscription' && g.status === 'active') .filter((g) => g.subscription_type === 'subscription' && g.status === 'active')
.map((g) => ({ value: g.id, label: g.name })) .map((g) => ({
value: g.id,
label: g.name,
description: g.description,
platform: g.platform,
subscriptionType: g.subscription_type,
rate: g.rate_multiplier
}))
) )
const applyFilters = () => { const applyFilters = () => {
......
...@@ -96,7 +96,7 @@ ...@@ -96,7 +96,7 @@
</div> </div>
<!-- Promo Code Input (Optional) --> <!-- Promo Code Input (Optional) -->
<div> <div v-if="promoCodeEnabled">
<label for="promo_code" class="input-label"> <label for="promo_code" class="input-label">
{{ t('auth.promoCodeLabel') }} {{ t('auth.promoCodeLabel') }}
<span class="ml-1 text-xs font-normal text-gray-400 dark:text-dark-500">({{ t('common.optional') }})</span> <span class="ml-1 text-xs font-normal text-gray-400 dark:text-dark-500">({{ t('common.optional') }})</span>
...@@ -260,6 +260,7 @@ const showPassword = ref<boolean>(false) ...@@ -260,6 +260,7 @@ const showPassword = ref<boolean>(false)
// Public settings // Public settings
const registrationEnabled = ref<boolean>(true) const registrationEnabled = ref<boolean>(true)
const emailVerifyEnabled = ref<boolean>(false) const emailVerifyEnabled = ref<boolean>(false)
const promoCodeEnabled = ref<boolean>(true)
const turnstileEnabled = ref<boolean>(false) const turnstileEnabled = ref<boolean>(false)
const turnstileSiteKey = ref<string>('') const turnstileSiteKey = ref<string>('')
const siteName = ref<string>('Sub2API') const siteName = ref<string>('Sub2API')
...@@ -294,22 +295,25 @@ const errors = reactive({ ...@@ -294,22 +295,25 @@ const errors = reactive({
// ==================== Lifecycle ==================== // ==================== Lifecycle ====================
onMounted(async () => { onMounted(async () => {
// Read promo code from URL parameter
const promoParam = route.query.promo as string
if (promoParam) {
formData.promo_code = promoParam
// Validate the promo code from URL
await validatePromoCodeDebounced(promoParam)
}
try { try {
const settings = await getPublicSettings() const settings = await getPublicSettings()
registrationEnabled.value = settings.registration_enabled registrationEnabled.value = settings.registration_enabled
emailVerifyEnabled.value = settings.email_verify_enabled emailVerifyEnabled.value = settings.email_verify_enabled
promoCodeEnabled.value = settings.promo_code_enabled
turnstileEnabled.value = settings.turnstile_enabled turnstileEnabled.value = settings.turnstile_enabled
turnstileSiteKey.value = settings.turnstile_site_key || '' turnstileSiteKey.value = settings.turnstile_site_key || ''
siteName.value = settings.site_name || 'Sub2API' siteName.value = settings.site_name || 'Sub2API'
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
// Read promo code from URL parameter only if promo code is enabled
if (promoCodeEnabled.value) {
const promoParam = route.query.promo as string
if (promoParam) {
formData.promo_code = promoParam
// Validate the promo code from URL
await validatePromoCodeDebounced(promoParam)
}
}
} catch (error) { } catch (error) {
console.error('Failed to load public settings:', error) console.error('Failed to load public settings:', error)
} finally { } finally {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment