Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
bdc426a7
"...views/admin/ops/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "a05711a37a071dabb34edaff420501ec66d90273"
Commit
bdc426a7
authored
Jan 18, 2026
by
yangjianbo
Browse files
Merge branch 'main' into dev
parents
771baa66
32fff379
Changes
44
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
bdc426a7
...
...
@@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache
:=
repository
.
ProvideConcurrencyCache
(
redisClient
,
configConfig
)
concurrencyService
:=
service
.
ProvideConcurrencyService
(
concurrencyCache
,
accountRepository
,
configConfig
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
configConfig
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
)
sessionLimitCache
:=
repository
.
ProvideSessionLimitCache
(
redisClient
,
configConfig
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
,
sessionLimitCache
)
oAuthHandler
:=
admin
.
NewOAuthHandler
(
oAuthService
)
openAIOAuthHandler
:=
admin
.
NewOpenAIOAuthHandler
(
openAIOAuthService
,
adminService
)
geminiOAuthHandler
:=
admin
.
NewGeminiOAuthHandler
(
geminiOAuthService
)
...
...
@@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService
:=
service
.
NewIdentityService
(
identityCache
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
claudeTokenProvider
:=
service
.
NewClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
,
sessionLimitCache
)
openAITokenProvider
:=
service
.
NewOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
...
...
backend/internal/config/config.go
View file @
bdc426a7
...
...
@@ -235,6 +235,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
// 空闲超过此时间的会话将被自动释放
SessionIdleTimeoutMinutes
int
`mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout
int
`mapstructure:"stream_data_interval_timeout"`
...
...
backend/internal/handler/admin/account_handler.go
View file @
bdc426a7
...
...
@@ -44,6 +44,7 @@ type AccountHandler struct {
accountTestService
*
service
.
AccountTestService
concurrencyService
*
service
.
ConcurrencyService
crsSyncService
*
service
.
CRSSyncService
sessionLimitCache
service
.
SessionLimitCache
}
// NewAccountHandler creates a new admin account handler
...
...
@@ -58,6 +59,7 @@ func NewAccountHandler(
accountTestService
*
service
.
AccountTestService
,
concurrencyService
*
service
.
ConcurrencyService
,
crsSyncService
*
service
.
CRSSyncService
,
sessionLimitCache
service
.
SessionLimitCache
,
)
*
AccountHandler
{
return
&
AccountHandler
{
adminService
:
adminService
,
...
...
@@ -70,6 +72,7 @@ func NewAccountHandler(
accountTestService
:
accountTestService
,
concurrencyService
:
concurrencyService
,
crsSyncService
:
crsSyncService
,
sessionLimitCache
:
sessionLimitCache
,
}
}
...
...
@@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
type
AccountWithConcurrency
struct
{
*
dto
.
Account
CurrentConcurrency
int
`json:"current_concurrency"`
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost
*
float64
`json:"current_window_cost,omitempty"`
// 当前窗口费用
ActiveSessions
*
int
`json:"active_sessions,omitempty"`
// 当前活跃会话数
}
// List handles listing all accounts with pagination
...
...
@@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts
=
make
(
map
[
int64
]
int
)
}
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs
:=
make
([]
int64
,
0
)
sessionLimitAccountIDs
:=
make
([]
int64
,
0
)
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
acc
.
IsAnthropicOAuthOrSetupToken
()
{
if
acc
.
GetWindowCostLimit
()
>
0
{
windowCostAccountIDs
=
append
(
windowCostAccountIDs
,
acc
.
ID
)
}
if
acc
.
GetMaxSessions
()
>
0
{
sessionLimitAccountIDs
=
append
(
sessionLimitAccountIDs
,
acc
.
ID
)
}
}
}
// 并行获取窗口费用和活跃会话数
var
windowCosts
map
[
int64
]
float64
var
activeSessions
map
[
int64
]
int
// 获取活跃会话数(批量查询)
if
len
(
sessionLimitAccountIDs
)
>
0
&&
h
.
sessionLimitCache
!=
nil
{
activeSessions
,
_
=
h
.
sessionLimitCache
.
GetActiveSessionCountBatch
(
c
.
Request
.
Context
(),
sessionLimitAccountIDs
)
if
activeSessions
==
nil
{
activeSessions
=
make
(
map
[
int64
]
int
)
}
}
// 获取窗口费用(并行查询)
if
len
(
windowCostAccountIDs
)
>
0
{
windowCosts
=
make
(
map
[
int64
]
float64
)
var
mu
sync
.
Mutex
g
,
gctx
:=
errgroup
.
WithContext
(
c
.
Request
.
Context
())
g
.
SetLimit
(
10
)
// 限制并发数
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
!
acc
.
IsAnthropicOAuthOrSetupToken
()
||
acc
.
GetWindowCostLimit
()
<=
0
{
continue
}
accCopy
:=
acc
// 闭包捕获
g
.
Go
(
func
()
error
{
var
startTime
time
.
Time
if
accCopy
.
SessionWindowStart
!=
nil
{
startTime
=
*
accCopy
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
h
.
accountUsageService
.
GetAccountWindowStats
(
gctx
,
accCopy
.
ID
,
startTime
)
if
err
==
nil
&&
stats
!=
nil
{
mu
.
Lock
()
windowCosts
[
accCopy
.
ID
]
=
stats
.
StandardCost
// 使用标准费用
mu
.
Unlock
()
}
return
nil
// 不返回错误,允许部分失败
})
}
_
=
g
.
Wait
()
}
// Build response with concurrency info
result
:=
make
([]
AccountWithConcurrency
,
len
(
accounts
))
for
i
:=
range
accounts
{
result
[
i
]
=
AccountWithConcurrency
{
Account
:
dto
.
AccountFromService
(
&
accounts
[
i
]),
CurrentConcurrency
:
concurrencyCounts
[
accounts
[
i
]
.
ID
],
acc
:=
&
accounts
[
i
]
item
:=
AccountWithConcurrency
{
Account
:
dto
.
AccountFromService
(
acc
),
CurrentConcurrency
:
concurrencyCounts
[
acc
.
ID
],
}
// 添加窗口费用(仅当启用时)
if
windowCosts
!=
nil
{
if
cost
,
ok
:=
windowCosts
[
acc
.
ID
];
ok
{
item
.
CurrentWindowCost
=
&
cost
}
}
// 添加活跃会话数(仅当启用时)
if
activeSessions
!=
nil
{
if
count
,
ok
:=
activeSessions
[
acc
.
ID
];
ok
{
item
.
ActiveSessions
=
&
count
}
}
result
[
i
]
=
item
}
response
.
Paginated
(
c
,
result
,
total
,
page
,
pageSize
)
...
...
backend/internal/handler/dto/mappers.go
View file @
bdc426a7
...
...
@@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if
a
==
nil
{
return
nil
}
return
&
Account
{
out
:=
&
Account
{
ID
:
a
.
ID
,
Name
:
a
.
Name
,
Notes
:
a
.
Notes
,
...
...
@@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus
:
a
.
SessionWindowStatus
,
GroupIDs
:
a
.
GroupIDs
,
}
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
if
a
.
IsAnthropicOAuthOrSetupToken
()
{
if
limit
:=
a
.
GetWindowCostLimit
();
limit
>
0
{
out
.
WindowCostLimit
=
&
limit
}
if
reserve
:=
a
.
GetWindowCostStickyReserve
();
reserve
>
0
{
out
.
WindowCostStickyReserve
=
&
reserve
}
if
maxSessions
:=
a
.
GetMaxSessions
();
maxSessions
>
0
{
out
.
MaxSessions
=
&
maxSessions
}
if
idleTimeout
:=
a
.
GetSessionIdleTimeoutMinutes
();
idleTimeout
>
0
{
out
.
SessionIdleTimeoutMin
=
&
idleTimeout
}
}
return
out
}
func
AccountFromService
(
a
*
service
.
Account
)
*
Account
{
...
...
backend/internal/handler/dto/types.go
View file @
bdc426a7
...
...
@@ -102,6 +102,16 @@ type Account struct {
SessionWindowEnd
*
time
.
Time
`json:"session_window_end"`
SessionWindowStatus
string
`json:"session_window_status"`
// 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
WindowCostLimit
*
float64
`json:"window_cost_limit,omitempty"`
WindowCostStickyReserve
*
float64
`json:"window_cost_sticky_reserve,omitempty"`
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
MaxSessions
*
int
`json:"max_sessions,omitempty"`
SessionIdleTimeoutMin
*
int
`json:"session_idle_timeout_minutes,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
...
...
backend/internal/handler/gateway_handler.go
View file @
bdc426a7
...
...
@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
bdc426a7
...
...
@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
bdc426a7
...
...
@@ -186,8 +186,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// Generate session hash (
from
header f
or OpenAI
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
c
)
// Generate session hash (header f
irst; fallback to prompt_cache_key
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
c
,
reqBody
)
const
maxAccountSwitches
=
3
switchCount
:=
0
...
...
backend/internal/pkg/gemini/models.go
View file @
bdc426a7
...
...
@@ -16,14 +16,11 @@ type ModelsListResponse struct {
func
DefaultModels
()
[]
Model
{
methods
:=
[]
string
{
"generateContent"
,
"streamGenerateContent"
}
return
[]
Model
{
{
Name
:
"models/gemini-3-pro-preview"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-3-flash-preview"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-pro"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.0-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-1.5-pro"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-1.5-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-1.5-flash-8b"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-pro"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-3-flash-preview"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-3-pro-preview"
,
SupportedGenerationMethods
:
methods
},
}
}
...
...
backend/internal/pkg/geminicli/models.go
View file @
bdc426a7
...
...
@@ -12,10 +12,10 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var
DefaultModels
=
[]
Model
{
{
ID
:
"gemini-2.0-flash"
,
Type
:
"model"
,
DisplayName
:
"Gemini 2.0 Flash"
,
CreatedAt
:
""
},
{
ID
:
"gemini-2.5-pro"
,
Type
:
"model"
,
DisplayName
:
"Gemini 2.5 Pro"
,
CreatedAt
:
""
},
{
ID
:
"gemini-2.5-flash"
,
Type
:
"model"
,
DisplayName
:
"Gemini 2.5 Flash"
,
CreatedAt
:
""
},
{
ID
:
"gemini-
3
-pro
-preview
"
,
Type
:
"model"
,
DisplayName
:
"Gemini
3
Pro
Preview
"
,
CreatedAt
:
""
},
{
ID
:
"gemini-
2.5
-pro"
,
Type
:
"model"
,
DisplayName
:
"Gemini
2.5
Pro"
,
CreatedAt
:
""
},
{
ID
:
"gemini-3-flash-preview"
,
Type
:
"model"
,
DisplayName
:
"Gemini 3 Flash Preview"
,
CreatedAt
:
""
},
{
ID
:
"gemini-3-pro-preview"
,
Type
:
"model"
,
DisplayName
:
"Gemini 3 Pro Preview"
,
CreatedAt
:
""
},
}
// DefaultTestModel is the default model to preselect in test flows.
...
...
backend/internal/repository/ent.go
View file @
bdc426a7
...
...
@@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client
:=
ent
.
NewClient
(
ent
.
Driver
(
drv
))
// SIMPLE 模式:启动时补齐各平台默认分组。
// - anthropic/openai/gemini: 确保存在 <platform>-default
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
if
cfg
.
RunMode
==
config
.
RunModeSimple
{
seedCtx
,
seedCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
seedCancel
()
if
err
:=
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
);
err
!=
nil
{
_
=
client
.
Close
()
return
nil
,
nil
,
err
}
}
return
client
,
drv
.
DB
(),
nil
}
backend/internal/repository/ops_repo.go
View file @
bdc426a7
...
...
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
// Excluded = business-limited errors (quota/concurrency/billing).
// Upstream 429/529 are included in errors view to match SLA calculation.
view
:=
""
if
filter
!=
nil
{
view
=
strings
.
ToLower
(
strings
.
TrimSpace
(
filter
.
View
))
...
...
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
switch
view
{
case
""
,
"errors"
:
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
case
"excluded"
:
clauses
=
append
(
clauses
,
"
(
COALESCE(is_business_limited,false) = true
OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))
"
)
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = true"
)
case
"all"
:
// no-op
default
:
// treat unknown as default 'errors'
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
}
if
len
(
filter
.
StatusCodes
)
>
0
{
args
=
append
(
args
,
pq
.
Array
(
filter
.
StatusCodes
))
...
...
backend/internal/repository/session_limit_cache.go
0 → 100644
View file @
bdc426a7
package
repository
import
(
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 会话限制缓存常量定义
//
// 设计说明:
// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话:
// - Key: session_limit:account:{accountID}
// - Member: sessionUUID(从 metadata.user_id 中提取)
// - Score: Unix 时间戳(会话最后活跃时间)
//
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
const
(
// 会话限制键前缀
// 格式: session_limit:account:{accountID}
sessionLimitKeyPrefix
=
"session_limit:account:"
// 窗口费用缓存键前缀
// 格式: window_cost:account:{accountID}
windowCostKeyPrefix
=
"window_cost:account:"
// 窗口费用缓存 TTL(30秒)
windowCostCacheTTL
=
30
*
time
.
Second
)
var
(
// registerSessionScript 注册会话活动
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = maxSessions
// ARGV[2] = idleTimeout(秒)
// ARGV[3] = sessionUUID
// 返回: 1 = 允许, 0 = 拒绝
registerSessionScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`
)
// refreshSessionScript 刷新会话时间戳
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
refreshSessionScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`
)
// getActiveSessionCountScript 获取活跃会话数
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
getActiveSessionCountScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`
)
// isSessionActiveScript 检查会话是否活跃
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
isSessionActiveScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`
)
)
type
sessionLimitCache
struct
{
rdb
*
redis
.
Client
defaultIdleTimeout
time
.
Duration
// 默认空闲超时(用于 GetActiveSessionCount)
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func
NewSessionLimitCache
(
rdb
*
redis
.
Client
,
defaultIdleTimeoutMinutes
int
)
service
.
SessionLimitCache
{
if
defaultIdleTimeoutMinutes
<=
0
{
defaultIdleTimeoutMinutes
=
5
// 默认 5 分钟
}
return
&
sessionLimitCache
{
rdb
:
rdb
,
defaultIdleTimeout
:
time
.
Duration
(
defaultIdleTimeoutMinutes
)
*
time
.
Minute
,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func
sessionLimitKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
sessionLimitKeyPrefix
,
accountID
)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func
windowCostKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
windowCostKeyPrefix
,
accountID
)
}
// RegisterSession 注册会话活动
func
(
c
*
sessionLimitCache
)
RegisterSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
maxSessions
int
,
idleTimeout
time
.
Duration
)
(
bool
,
error
)
{
if
sessionUUID
==
""
||
maxSessions
<=
0
{
return
true
,
nil
// 无效参数,默认允许
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
result
,
err
:=
registerSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
maxSessions
,
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
true
,
err
// 失败开放:缓存错误时允许请求通过
}
return
result
==
1
,
nil
}
// RefreshSession 刷新会话时间戳
func
(
c
*
sessionLimitCache
)
RefreshSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
idleTimeout
time
.
Duration
)
error
{
if
sessionUUID
==
""
{
return
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
_
,
err
:=
refreshSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Result
()
return
err
}
// GetActiveSessionCount 获取活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
getActiveSessionCountScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
)
.
Int
()
if
err
!=
nil
{
return
0
,
err
}
return
result
,
nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
int
),
nil
}
results
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
// 使用 pipeline 批量执行
pipe
:=
c
.
rdb
.
Pipeline
()
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
cmds
:=
make
(
map
[
int64
]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
key
:=
sessionLimitKey
(
accountID
)
cmds
[
accountID
]
=
getActiveSessionCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
idleTimeoutSeconds
)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_
,
_
=
pipe
.
Exec
(
ctx
)
for
accountID
,
cmd
:=
range
cmds
{
if
result
,
err
:=
cmd
.
Int
();
err
==
nil
{
results
[
accountID
]
=
result
}
}
return
results
,
nil
}
// IsSessionActive 检查会话是否活跃
func
(
c
*
sessionLimitCache
)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
{
if
sessionUUID
==
""
{
return
false
,
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
isSessionActiveScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
false
,
err
}
return
result
==
1
,
nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func
(
c
*
sessionLimitCache
)
GetWindowCost
(
ctx
context
.
Context
,
accountID
int64
)
(
float64
,
bool
,
error
)
{
key
:=
windowCostKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Float64
()
if
err
==
redis
.
Nil
{
return
0
,
false
,
nil
// 缓存未命中
}
if
err
!=
nil
{
return
0
,
false
,
err
}
return
val
,
true
,
nil
}
// SetWindowCost 设置窗口费用缓存
func
(
c
*
sessionLimitCache
)
SetWindowCost
(
ctx
context
.
Context
,
accountID
int64
,
cost
float64
)
error
{
key
:=
windowCostKey
(
accountID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
cost
,
windowCostCacheTTL
)
.
Err
()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func
(
c
*
sessionLimitCache
)
GetWindowCostBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
float64
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
float64
),
nil
}
// 构建批量查询的 keys
keys
:=
make
([]
string
,
len
(
accountIDs
))
for
i
,
accountID
:=
range
accountIDs
{
keys
[
i
]
=
windowCostKey
(
accountID
)
}
// 使用 MGET 批量获取
vals
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
results
:=
make
(
map
[
int64
]
float64
,
len
(
accountIDs
))
for
i
,
val
:=
range
vals
{
if
val
==
nil
{
continue
// 缓存未命中
}
// 尝试解析为 float64
switch
v
:=
val
.
(
type
)
{
case
string
:
if
cost
,
err
:=
strconv
.
ParseFloat
(
v
,
64
);
err
==
nil
{
results
[
accountIDs
[
i
]]
=
cost
}
case
float64
:
results
[
accountIDs
[
i
]]
=
v
}
}
return
results
,
nil
}
backend/internal/repository/simple_mode_default_groups.go
0 → 100644
View file @
bdc426a7
package
repository
import
(
"context"
"fmt"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
ensureSimpleModeDefaultGroups
(
ctx
context
.
Context
,
client
*
dbent
.
Client
)
error
{
if
client
==
nil
{
return
fmt
.
Errorf
(
"nil ent client"
)
}
requiredByPlatform
:=
map
[
string
]
int
{
service
.
PlatformAnthropic
:
1
,
service
.
PlatformOpenAI
:
1
,
service
.
PlatformGemini
:
1
,
service
.
PlatformAntigravity
:
2
,
}
for
platform
,
minCount
:=
range
requiredByPlatform
{
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
platform
),
group
.
DeletedAtIsNil
())
.
Count
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"count groups for platform %s: %w"
,
platform
,
err
)
}
if
platform
==
service
.
PlatformAntigravity
{
if
count
<
minCount
{
for
i
:=
count
;
i
<
minCount
;
i
++
{
name
:=
fmt
.
Sprintf
(
"%s-default-%d"
,
platform
,
i
+
1
)
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
}
continue
}
// Non-antigravity platforms: ensure <platform>-default exists.
name
:=
platform
+
"-default"
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
return
nil
}
func
createGroupIfNotExists
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
name
,
platform
string
)
error
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check group exists %s: %w"
,
name
,
err
)
}
if
exists
{
return
nil
}
_
,
err
=
client
.
Group
.
Create
()
.
SetName
(
name
)
.
SetDescription
(
"Auto-created default group"
)
.
SetPlatform
(
platform
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsConstraintError
(
err
)
{
// Concurrent server startups may race on creation; treat as success.
return
nil
}
return
fmt
.
Errorf
(
"create default group %s: %w"
,
name
,
err
)
}
return
nil
}
backend/internal/repository/simple_mode_default_groups_integration_test.go
0 → 100644
View file @
bdc426a7
//go:build integration
package
repository
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
assertGroupExists
:=
func
(
name
string
)
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
exists
,
"expected group %s to exist"
,
name
)
}
assertGroupExists
(
service
.
PlatformAnthropic
+
"-default"
)
assertGroupExists
(
service
.
PlatformOpenAI
+
"-default"
)
assertGroupExists
(
service
.
PlatformGemini
+
"-default"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-1"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-2"
)
}
func
TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
// Create and then soft-delete an anthropic default group.
g
,
err
:=
client
.
Group
.
Create
()
.
SetName
(
service
.
PlatformAnthropic
+
"-default"
)
.
SetPlatform
(
service
.
PlatformAnthropic
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
seedCtx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
Group
.
Delete
()
.
Where
(
group
.
IDEQ
(
g
.
ID
))
.
Exec
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
// New active one should exist.
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
service
.
PlatformAnthropic
+
"-default"
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
}
func
TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-1-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-2-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
service
.
PlatformAntigravity
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
GreaterOrEqual
(
t
,
count
,
2
)
}
backend/internal/repository/wire.go
View file @
bdc426a7
...
...
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return
NewPricingRemoteClient
(
cfg
.
Update
.
ProxyURL
)
}
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func
ProvideSessionLimitCache
(
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
)
service
.
SessionLimitCache
{
defaultIdleTimeoutMinutes
:=
5
// 默认 5 分钟空闲超时
if
cfg
!=
nil
&&
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
>
0
{
defaultIdleTimeoutMinutes
=
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
}
return
NewSessionLimitCache
(
rdb
,
defaultIdleTimeoutMinutes
)
}
// ProviderSet is the Wire provider set for all repositories
var
ProviderSet
=
wire
.
NewSet
(
NewUserRepository
,
...
...
@@ -62,6 +72,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache
,
NewTimeoutCounterCache
,
ProvideConcurrencyCache
,
ProvideSessionLimitCache
,
NewDashboardCache
,
NewEmailCache
,
NewIdentityCache
,
...
...
backend/internal/server/api_contract_test.go
View file @
bdc426a7
...
...
@@ -441,7 +441,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
backend/internal/service/account.go
View file @
bdc426a7
...
...
@@ -557,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
return
false
}
// WindowCostSchedulability 窗口费用调度状态
type
WindowCostSchedulability
int
const
(
// WindowCostSchedulable 可正常调度
WindowCostSchedulable
WindowCostSchedulability
=
iota
// WindowCostStickyOnly 仅允许粘性会话
WindowCostStickyOnly
// WindowCostNotSchedulable 完全不可调度
WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
func
(
a
*
Account
)
IsAnthropicOAuthOrSetupToken
()
bool
{
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeSetupToken
)
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_limit"
];
ok
{
return
parseExtraFloat64
(
v
)
}
return
0
}
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
// 默认值为 10
func
(
a
*
Account
)
GetWindowCostStickyReserve
()
float64
{
if
a
.
Extra
==
nil
{
return
10.0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_sticky_reserve"
];
ok
{
val
:=
parseExtraFloat64
(
v
)
if
val
>
0
{
return
val
}
}
return
10.0
}
// GetMaxSessions 获取最大并发会话数
// 返回 0 表示未启用
func
(
a
*
Account
)
GetMaxSessions
()
int
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"max_sessions"
];
ok
{
return
parseExtraInt
(
v
)
}
return
0
}
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
// 默认值为 5 分钟
func
(
a
*
Account
)
GetSessionIdleTimeoutMinutes
()
int
{
if
a
.
Extra
==
nil
{
return
5
}
if
v
,
ok
:=
a
.
Extra
[
"session_idle_timeout_minutes"
];
ok
{
val
:=
parseExtraInt
(
v
)
if
val
>
0
{
return
val
}
}
return
5
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
func
(
a
*
Account
)
CheckWindowCostSchedulability
(
currentWindowCost
float64
)
WindowCostSchedulability
{
limit
:=
a
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
WindowCostSchedulable
}
if
currentWindowCost
<
limit
{
return
WindowCostSchedulable
}
stickyReserve
:=
a
.
GetWindowCostStickyReserve
()
if
currentWindowCost
<
limit
+
stickyReserve
{
return
WindowCostStickyOnly
}
return
WindowCostNotSchedulable
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func
parseExtraFloat64
(
value
any
)
float64
{
switch
v
:=
value
.
(
type
)
{
case
float64
:
return
v
case
float32
:
return
float64
(
v
)
case
int
:
return
float64
(
v
)
case
int64
:
return
float64
(
v
)
case
json
.
Number
:
if
f
,
err
:=
v
.
Float64
();
err
==
nil
{
return
f
}
case
string
:
if
f
,
err
:=
strconv
.
ParseFloat
(
strings
.
TrimSpace
(
v
),
64
);
err
==
nil
{
return
f
}
}
return
0
}
// parseExtraInt 从 extra 字段解析 int 值
func
parseExtraInt
(
value
any
)
int
{
switch
v
:=
value
.
(
type
)
{
case
int
:
return
v
case
int64
:
return
int
(
v
)
case
float64
:
return
int
(
v
)
case
json
.
Number
:
if
i
,
err
:=
v
.
Int64
();
err
==
nil
{
return
int
(
i
)
}
case
string
:
if
i
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
v
));
err
==
nil
{
return
i
}
}
return
0
}
backend/internal/service/account_usage_service.go
View file @
bdc426a7
...
...
@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
// 用于账号列表页面显示当前窗口费用
func
(
s
*
AccountUsageService
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
accountID
,
startTime
)
}
backend/internal/service/antigravity_model_mapping_test.go
View file @
bdc426a7
...
...
@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{
"可映射 - claude-3-haiku-20240307"
,
"claude-3-haiku-20240307"
,
true
},
// Gemini 前缀透传
{
"Gemini前缀 - gemini-
1
.5-pro"
,
"gemini-
1
.5-pro"
,
true
},
{
"Gemini前缀 - gemini-
2
.5-pro"
,
"gemini-
2
.5-pro"
,
true
},
{
"Gemini前缀 - gemini-unknown-model"
,
"gemini-unknown-model"
,
true
},
{
"Gemini前缀 - gemini-future-version"
,
"gemini-future-version"
,
true
},
...
...
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected
:
"gemini-2.5-flash"
,
},
{
name
:
"Gemini透传 - gemini-
1
.5-pro"
,
requestedModel
:
"gemini-
1
.5-pro"
,
name
:
"Gemini透传 - gemini-
2
.5-pro"
,
requestedModel
:
"gemini-
2
.5-pro"
,
accountMapping
:
nil
,
expected
:
"gemini-
1
.5-pro"
,
expected
:
"gemini-
2
.5-pro"
,
},
{
name
:
"Gemini透传 - gemini-future-model"
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment