Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
bdc426a7
Commit
bdc426a7
authored
Jan 18, 2026
by
yangjianbo
Browse files
Merge branch 'main' into dev
parents
771baa66
32fff379
Changes
44
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
bdc426a7
...
...
@@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache
:=
repository
.
ProvideConcurrencyCache
(
redisClient
,
configConfig
)
concurrencyService
:=
service
.
ProvideConcurrencyService
(
concurrencyCache
,
accountRepository
,
configConfig
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
configConfig
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
)
sessionLimitCache
:=
repository
.
ProvideSessionLimitCache
(
redisClient
,
configConfig
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
,
sessionLimitCache
)
oAuthHandler
:=
admin
.
NewOAuthHandler
(
oAuthService
)
openAIOAuthHandler
:=
admin
.
NewOpenAIOAuthHandler
(
openAIOAuthService
,
adminService
)
geminiOAuthHandler
:=
admin
.
NewGeminiOAuthHandler
(
geminiOAuthService
)
...
...
@@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService
:=
service
.
NewIdentityService
(
identityCache
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
claudeTokenProvider
:=
service
.
NewClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
,
sessionLimitCache
)
openAITokenProvider
:=
service
.
NewOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
...
...
backend/internal/config/config.go
View file @
bdc426a7
...
...
@@ -235,6 +235,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
// 空闲超过此时间的会话将被自动释放
SessionIdleTimeoutMinutes
int
`mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout
int
`mapstructure:"stream_data_interval_timeout"`
...
...
backend/internal/handler/admin/account_handler.go
View file @
bdc426a7
...
...
@@ -44,6 +44,7 @@ type AccountHandler struct {
accountTestService
*
service
.
AccountTestService
concurrencyService
*
service
.
ConcurrencyService
crsSyncService
*
service
.
CRSSyncService
sessionLimitCache
service
.
SessionLimitCache
}
// NewAccountHandler creates a new admin account handler
...
...
@@ -58,6 +59,7 @@ func NewAccountHandler(
accountTestService
*
service
.
AccountTestService
,
concurrencyService
*
service
.
ConcurrencyService
,
crsSyncService
*
service
.
CRSSyncService
,
sessionLimitCache
service
.
SessionLimitCache
,
)
*
AccountHandler
{
return
&
AccountHandler
{
adminService
:
adminService
,
...
...
@@ -70,6 +72,7 @@ func NewAccountHandler(
accountTestService
:
accountTestService
,
concurrencyService
:
concurrencyService
,
crsSyncService
:
crsSyncService
,
sessionLimitCache
:
sessionLimitCache
,
}
}
...
...
@@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
type
AccountWithConcurrency
struct
{
*
dto
.
Account
CurrentConcurrency
int
`json:"current_concurrency"`
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost
*
float64
`json:"current_window_cost,omitempty"`
// 当前窗口费用
ActiveSessions
*
int
`json:"active_sessions,omitempty"`
// 当前活跃会话数
}
// List handles listing all accounts with pagination
...
...
@@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts
=
make
(
map
[
int64
]
int
)
}
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs
:=
make
([]
int64
,
0
)
sessionLimitAccountIDs
:=
make
([]
int64
,
0
)
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
acc
.
IsAnthropicOAuthOrSetupToken
()
{
if
acc
.
GetWindowCostLimit
()
>
0
{
windowCostAccountIDs
=
append
(
windowCostAccountIDs
,
acc
.
ID
)
}
if
acc
.
GetMaxSessions
()
>
0
{
sessionLimitAccountIDs
=
append
(
sessionLimitAccountIDs
,
acc
.
ID
)
}
}
}
// 并行获取窗口费用和活跃会话数
var
windowCosts
map
[
int64
]
float64
var
activeSessions
map
[
int64
]
int
// 获取活跃会话数(批量查询)
if
len
(
sessionLimitAccountIDs
)
>
0
&&
h
.
sessionLimitCache
!=
nil
{
activeSessions
,
_
=
h
.
sessionLimitCache
.
GetActiveSessionCountBatch
(
c
.
Request
.
Context
(),
sessionLimitAccountIDs
)
if
activeSessions
==
nil
{
activeSessions
=
make
(
map
[
int64
]
int
)
}
}
// 获取窗口费用(并行查询)
if
len
(
windowCostAccountIDs
)
>
0
{
windowCosts
=
make
(
map
[
int64
]
float64
)
var
mu
sync
.
Mutex
g
,
gctx
:=
errgroup
.
WithContext
(
c
.
Request
.
Context
())
g
.
SetLimit
(
10
)
// 限制并发数
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
!
acc
.
IsAnthropicOAuthOrSetupToken
()
||
acc
.
GetWindowCostLimit
()
<=
0
{
continue
}
accCopy
:=
acc
// 闭包捕获
g
.
Go
(
func
()
error
{
var
startTime
time
.
Time
if
accCopy
.
SessionWindowStart
!=
nil
{
startTime
=
*
accCopy
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
h
.
accountUsageService
.
GetAccountWindowStats
(
gctx
,
accCopy
.
ID
,
startTime
)
if
err
==
nil
&&
stats
!=
nil
{
mu
.
Lock
()
windowCosts
[
accCopy
.
ID
]
=
stats
.
StandardCost
// 使用标准费用
mu
.
Unlock
()
}
return
nil
// 不返回错误,允许部分失败
})
}
_
=
g
.
Wait
()
}
// Build response with concurrency info
result
:=
make
([]
AccountWithConcurrency
,
len
(
accounts
))
for
i
:=
range
accounts
{
result
[
i
]
=
AccountWithConcurrency
{
Account
:
dto
.
AccountFromService
(
&
accounts
[
i
]),
CurrentConcurrency
:
concurrencyCounts
[
accounts
[
i
]
.
ID
],
acc
:=
&
accounts
[
i
]
item
:=
AccountWithConcurrency
{
Account
:
dto
.
AccountFromService
(
acc
),
CurrentConcurrency
:
concurrencyCounts
[
acc
.
ID
],
}
// 添加窗口费用(仅当启用时)
if
windowCosts
!=
nil
{
if
cost
,
ok
:=
windowCosts
[
acc
.
ID
];
ok
{
item
.
CurrentWindowCost
=
&
cost
}
}
// 添加活跃会话数(仅当启用时)
if
activeSessions
!=
nil
{
if
count
,
ok
:=
activeSessions
[
acc
.
ID
];
ok
{
item
.
ActiveSessions
=
&
count
}
}
result
[
i
]
=
item
}
response
.
Paginated
(
c
,
result
,
total
,
page
,
pageSize
)
...
...
backend/internal/handler/dto/mappers.go
View file @
bdc426a7
...
...
@@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if
a
==
nil
{
return
nil
}
return
&
Account
{
out
:=
&
Account
{
ID
:
a
.
ID
,
Name
:
a
.
Name
,
Notes
:
a
.
Notes
,
...
...
@@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus
:
a
.
SessionWindowStatus
,
GroupIDs
:
a
.
GroupIDs
,
}
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
if
a
.
IsAnthropicOAuthOrSetupToken
()
{
if
limit
:=
a
.
GetWindowCostLimit
();
limit
>
0
{
out
.
WindowCostLimit
=
&
limit
}
if
reserve
:=
a
.
GetWindowCostStickyReserve
();
reserve
>
0
{
out
.
WindowCostStickyReserve
=
&
reserve
}
if
maxSessions
:=
a
.
GetMaxSessions
();
maxSessions
>
0
{
out
.
MaxSessions
=
&
maxSessions
}
if
idleTimeout
:=
a
.
GetSessionIdleTimeoutMinutes
();
idleTimeout
>
0
{
out
.
SessionIdleTimeoutMin
=
&
idleTimeout
}
}
return
out
}
func
AccountFromService
(
a
*
service
.
Account
)
*
Account
{
...
...
backend/internal/handler/dto/types.go
View file @
bdc426a7
...
...
@@ -102,6 +102,16 @@ type Account struct {
SessionWindowEnd
*
time
.
Time
`json:"session_window_end"`
SessionWindowStatus
string
`json:"session_window_status"`
// 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
WindowCostLimit
*
float64
`json:"window_cost_limit,omitempty"`
WindowCostStickyReserve
*
float64
`json:"window_cost_sticky_reserve,omitempty"`
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
MaxSessions
*
int
`json:"max_sessions,omitempty"`
SessionIdleTimeoutMin
*
int
`json:"session_idle_timeout_minutes,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
...
...
backend/internal/handler/gateway_handler.go
View file @
bdc426a7
...
...
@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
bdc426a7
...
...
@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
bdc426a7
...
...
@@ -186,8 +186,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// Generate session hash (
from
header f
or OpenAI
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
c
)
// Generate session hash (header f
irst; fallback to prompt_cache_key
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
c
,
reqBody
)
const
maxAccountSwitches
=
3
switchCount
:=
0
...
...
backend/internal/pkg/gemini/models.go
View file @
bdc426a7
...
...
@@ -16,14 +16,11 @@ type ModelsListResponse struct {
func
DefaultModels
()
[]
Model
{
methods
:=
[]
string
{
"generateContent"
,
"streamGenerateContent"
}
return
[]
Model
{
{
Name
:
"models/gemini-3-pro-preview"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-3-flash-preview"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-pro"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.0-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-1.5-pro"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-1.5-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-1.5-flash-8b"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-flash"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-2.5-pro"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-3-flash-preview"
,
SupportedGenerationMethods
:
methods
},
{
Name
:
"models/gemini-3-pro-preview"
,
SupportedGenerationMethods
:
methods
},
}
}
...
...
backend/internal/pkg/geminicli/models.go
View file @
bdc426a7
...
...
@@ -12,10 +12,10 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var
DefaultModels
=
[]
Model
{
{
ID
:
"gemini-2.0-flash"
,
Type
:
"model"
,
DisplayName
:
"Gemini 2.0 Flash"
,
CreatedAt
:
""
},
{
ID
:
"gemini-2.5-pro"
,
Type
:
"model"
,
DisplayName
:
"Gemini 2.5 Pro"
,
CreatedAt
:
""
},
{
ID
:
"gemini-2.5-flash"
,
Type
:
"model"
,
DisplayName
:
"Gemini 2.5 Flash"
,
CreatedAt
:
""
},
{
ID
:
"gemini-
3
-pro
-preview
"
,
Type
:
"model"
,
DisplayName
:
"Gemini
3
Pro
Preview
"
,
CreatedAt
:
""
},
{
ID
:
"gemini-
2.5
-pro"
,
Type
:
"model"
,
DisplayName
:
"Gemini
2.5
Pro"
,
CreatedAt
:
""
},
{
ID
:
"gemini-3-flash-preview"
,
Type
:
"model"
,
DisplayName
:
"Gemini 3 Flash Preview"
,
CreatedAt
:
""
},
{
ID
:
"gemini-3-pro-preview"
,
Type
:
"model"
,
DisplayName
:
"Gemini 3 Pro Preview"
,
CreatedAt
:
""
},
}
// DefaultTestModel is the default model to preselect in test flows.
...
...
backend/internal/repository/ent.go
View file @
bdc426a7
...
...
@@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client
:=
ent
.
NewClient
(
ent
.
Driver
(
drv
))
// SIMPLE 模式:启动时补齐各平台默认分组。
// - anthropic/openai/gemini: 确保存在 <platform>-default
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
if
cfg
.
RunMode
==
config
.
RunModeSimple
{
seedCtx
,
seedCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
seedCancel
()
if
err
:=
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
);
err
!=
nil
{
_
=
client
.
Close
()
return
nil
,
nil
,
err
}
}
return
client
,
drv
.
DB
(),
nil
}
backend/internal/repository/ops_repo.go
View file @
bdc426a7
...
...
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
// Excluded = business-limited errors (quota/concurrency/billing).
// Upstream 429/529 are included in errors view to match SLA calculation.
view
:=
""
if
filter
!=
nil
{
view
=
strings
.
ToLower
(
strings
.
TrimSpace
(
filter
.
View
))
...
...
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
switch
view
{
case
""
,
"errors"
:
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
case
"excluded"
:
clauses
=
append
(
clauses
,
"
(
COALESCE(is_business_limited,false) = true
OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))
"
)
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = true"
)
case
"all"
:
// no-op
default
:
// treat unknown as default 'errors'
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
}
if
len
(
filter
.
StatusCodes
)
>
0
{
args
=
append
(
args
,
pq
.
Array
(
filter
.
StatusCodes
))
...
...
backend/internal/repository/session_limit_cache.go
0 → 100644
View file @
bdc426a7
package
repository
import
(
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 会话限制缓存常量定义
//
// 设计说明:
// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话:
// - Key: session_limit:account:{accountID}
// - Member: sessionUUID(从 metadata.user_id 中提取)
// - Score: Unix 时间戳(会话最后活跃时间)
//
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
const
(
// 会话限制键前缀
// 格式: session_limit:account:{accountID}
sessionLimitKeyPrefix
=
"session_limit:account:"
// 窗口费用缓存键前缀
// 格式: window_cost:account:{accountID}
windowCostKeyPrefix
=
"window_cost:account:"
// 窗口费用缓存 TTL(30秒)
windowCostCacheTTL
=
30
*
time
.
Second
)
var
(
// registerSessionScript 注册会话活动
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = maxSessions
// ARGV[2] = idleTimeout(秒)
// ARGV[3] = sessionUUID
// 返回: 1 = 允许, 0 = 拒绝
registerSessionScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`
)
// refreshSessionScript 刷新会话时间戳
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
refreshSessionScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`
)
// getActiveSessionCountScript 获取活跃会话数
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
getActiveSessionCountScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`
)
// isSessionActiveScript 检查会话是否活跃
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
isSessionActiveScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`
)
)
type
sessionLimitCache
struct
{
rdb
*
redis
.
Client
defaultIdleTimeout
time
.
Duration
// 默认空闲超时(用于 GetActiveSessionCount)
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func
NewSessionLimitCache
(
rdb
*
redis
.
Client
,
defaultIdleTimeoutMinutes
int
)
service
.
SessionLimitCache
{
if
defaultIdleTimeoutMinutes
<=
0
{
defaultIdleTimeoutMinutes
=
5
// 默认 5 分钟
}
return
&
sessionLimitCache
{
rdb
:
rdb
,
defaultIdleTimeout
:
time
.
Duration
(
defaultIdleTimeoutMinutes
)
*
time
.
Minute
,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func
sessionLimitKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
sessionLimitKeyPrefix
,
accountID
)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func
windowCostKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
windowCostKeyPrefix
,
accountID
)
}
// RegisterSession 注册会话活动
func
(
c
*
sessionLimitCache
)
RegisterSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
maxSessions
int
,
idleTimeout
time
.
Duration
)
(
bool
,
error
)
{
if
sessionUUID
==
""
||
maxSessions
<=
0
{
return
true
,
nil
// 无效参数,默认允许
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
result
,
err
:=
registerSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
maxSessions
,
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
true
,
err
// 失败开放:缓存错误时允许请求通过
}
return
result
==
1
,
nil
}
// RefreshSession 刷新会话时间戳
func
(
c
*
sessionLimitCache
)
RefreshSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
idleTimeout
time
.
Duration
)
error
{
if
sessionUUID
==
""
{
return
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
_
,
err
:=
refreshSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Result
()
return
err
}
// GetActiveSessionCount 获取活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
getActiveSessionCountScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
)
.
Int
()
if
err
!=
nil
{
return
0
,
err
}
return
result
,
nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
int
),
nil
}
results
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
// 使用 pipeline 批量执行
pipe
:=
c
.
rdb
.
Pipeline
()
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
cmds
:=
make
(
map
[
int64
]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
key
:=
sessionLimitKey
(
accountID
)
cmds
[
accountID
]
=
getActiveSessionCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
idleTimeoutSeconds
)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_
,
_
=
pipe
.
Exec
(
ctx
)
for
accountID
,
cmd
:=
range
cmds
{
if
result
,
err
:=
cmd
.
Int
();
err
==
nil
{
results
[
accountID
]
=
result
}
}
return
results
,
nil
}
// IsSessionActive 检查会话是否活跃
func
(
c
*
sessionLimitCache
)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
{
if
sessionUUID
==
""
{
return
false
,
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
isSessionActiveScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
false
,
err
}
return
result
==
1
,
nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func
(
c
*
sessionLimitCache
)
GetWindowCost
(
ctx
context
.
Context
,
accountID
int64
)
(
float64
,
bool
,
error
)
{
key
:=
windowCostKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Float64
()
if
err
==
redis
.
Nil
{
return
0
,
false
,
nil
// 缓存未命中
}
if
err
!=
nil
{
return
0
,
false
,
err
}
return
val
,
true
,
nil
}
// SetWindowCost 设置窗口费用缓存
func
(
c
*
sessionLimitCache
)
SetWindowCost
(
ctx
context
.
Context
,
accountID
int64
,
cost
float64
)
error
{
key
:=
windowCostKey
(
accountID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
cost
,
windowCostCacheTTL
)
.
Err
()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func
(
c
*
sessionLimitCache
)
GetWindowCostBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
float64
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
float64
),
nil
}
// 构建批量查询的 keys
keys
:=
make
([]
string
,
len
(
accountIDs
))
for
i
,
accountID
:=
range
accountIDs
{
keys
[
i
]
=
windowCostKey
(
accountID
)
}
// 使用 MGET 批量获取
vals
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
results
:=
make
(
map
[
int64
]
float64
,
len
(
accountIDs
))
for
i
,
val
:=
range
vals
{
if
val
==
nil
{
continue
// 缓存未命中
}
// 尝试解析为 float64
switch
v
:=
val
.
(
type
)
{
case
string
:
if
cost
,
err
:=
strconv
.
ParseFloat
(
v
,
64
);
err
==
nil
{
results
[
accountIDs
[
i
]]
=
cost
}
case
float64
:
results
[
accountIDs
[
i
]]
=
v
}
}
return
results
,
nil
}
backend/internal/repository/simple_mode_default_groups.go
0 → 100644
View file @
bdc426a7
package
repository
import
(
"context"
"fmt"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
ensureSimpleModeDefaultGroups
(
ctx
context
.
Context
,
client
*
dbent
.
Client
)
error
{
if
client
==
nil
{
return
fmt
.
Errorf
(
"nil ent client"
)
}
requiredByPlatform
:=
map
[
string
]
int
{
service
.
PlatformAnthropic
:
1
,
service
.
PlatformOpenAI
:
1
,
service
.
PlatformGemini
:
1
,
service
.
PlatformAntigravity
:
2
,
}
for
platform
,
minCount
:=
range
requiredByPlatform
{
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
platform
),
group
.
DeletedAtIsNil
())
.
Count
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"count groups for platform %s: %w"
,
platform
,
err
)
}
if
platform
==
service
.
PlatformAntigravity
{
if
count
<
minCount
{
for
i
:=
count
;
i
<
minCount
;
i
++
{
name
:=
fmt
.
Sprintf
(
"%s-default-%d"
,
platform
,
i
+
1
)
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
}
continue
}
// Non-antigravity platforms: ensure <platform>-default exists.
name
:=
platform
+
"-default"
if
err
:=
createGroupIfNotExists
(
ctx
,
client
,
name
,
platform
);
err
!=
nil
{
return
err
}
}
return
nil
}
func
createGroupIfNotExists
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
name
,
platform
string
)
error
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check group exists %s: %w"
,
name
,
err
)
}
if
exists
{
return
nil
}
_
,
err
=
client
.
Group
.
Create
()
.
SetName
(
name
)
.
SetDescription
(
"Auto-created default group"
)
.
SetPlatform
(
platform
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsConstraintError
(
err
)
{
// Concurrent server startups may race on creation; treat as success.
return
nil
}
return
fmt
.
Errorf
(
"create default group %s: %w"
,
name
,
err
)
}
return
nil
}
backend/internal/repository/simple_mode_default_groups_integration_test.go
0 → 100644
View file @
bdc426a7
//go:build integration
package
repository
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
assertGroupExists
:=
func
(
name
string
)
{
exists
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
),
group
.
DeletedAtIsNil
())
.
Exist
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
exists
,
"expected group %s to exist"
,
name
)
}
assertGroupExists
(
service
.
PlatformAnthropic
+
"-default"
)
assertGroupExists
(
service
.
PlatformOpenAI
+
"-default"
)
assertGroupExists
(
service
.
PlatformGemini
+
"-default"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-1"
)
assertGroupExists
(
service
.
PlatformAntigravity
+
"-default-2"
)
}
func
TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
// Create and then soft-delete an anthropic default group.
g
,
err
:=
client
.
Group
.
Create
()
.
SetName
(
service
.
PlatformAnthropic
+
"-default"
)
.
SetPlatform
(
service
.
PlatformAnthropic
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1.0
)
.
SetIsExclusive
(
false
)
.
Save
(
seedCtx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
Group
.
Delete
()
.
Where
(
group
.
IDEQ
(
g
.
ID
))
.
Exec
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
// New active one should exist.
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
service
.
PlatformAnthropic
+
"-default"
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
}
func
TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
tx
:=
testEntTx
(
t
)
client
:=
tx
.
Client
()
seedCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-1-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
mustCreateGroup
(
t
,
client
,
&
service
.
Group
{
Name
:
"ag-custom-2-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Platform
:
service
.
PlatformAntigravity
})
require
.
NoError
(
t
,
ensureSimpleModeDefaultGroups
(
seedCtx
,
client
))
count
,
err
:=
client
.
Group
.
Query
()
.
Where
(
group
.
PlatformEQ
(
service
.
PlatformAntigravity
),
group
.
DeletedAtIsNil
())
.
Count
(
seedCtx
)
require
.
NoError
(
t
,
err
)
require
.
GreaterOrEqual
(
t
,
count
,
2
)
}
backend/internal/repository/wire.go
View file @
bdc426a7
...
...
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return
NewPricingRemoteClient
(
cfg
.
Update
.
ProxyURL
)
}
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func
ProvideSessionLimitCache
(
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
)
service
.
SessionLimitCache
{
defaultIdleTimeoutMinutes
:=
5
// 默认 5 分钟空闲超时
if
cfg
!=
nil
&&
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
>
0
{
defaultIdleTimeoutMinutes
=
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
}
return
NewSessionLimitCache
(
rdb
,
defaultIdleTimeoutMinutes
)
}
// ProviderSet is the Wire provider set for all repositories
var
ProviderSet
=
wire
.
NewSet
(
NewUserRepository
,
...
...
@@ -62,6 +72,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache
,
NewTimeoutCounterCache
,
ProvideConcurrencyCache
,
ProvideSessionLimitCache
,
NewDashboardCache
,
NewEmailCache
,
NewIdentityCache
,
...
...
backend/internal/server/api_contract_test.go
View file @
bdc426a7
...
...
@@ -441,7 +441,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
backend/internal/service/account.go
View file @
bdc426a7
...
...
@@ -557,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
return
false
}
// WindowCostSchedulability 窗口费用调度状态
type
WindowCostSchedulability
int
const
(
// WindowCostSchedulable 可正常调度
WindowCostSchedulable
WindowCostSchedulability
=
iota
// WindowCostStickyOnly 仅允许粘性会话
WindowCostStickyOnly
// WindowCostNotSchedulable 完全不可调度
WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
func
(
a
*
Account
)
IsAnthropicOAuthOrSetupToken
()
bool
{
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeSetupToken
)
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_limit"
];
ok
{
return
parseExtraFloat64
(
v
)
}
return
0
}
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
// 默认值为 10
func
(
a
*
Account
)
GetWindowCostStickyReserve
()
float64
{
if
a
.
Extra
==
nil
{
return
10.0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_sticky_reserve"
];
ok
{
val
:=
parseExtraFloat64
(
v
)
if
val
>
0
{
return
val
}
}
return
10.0
}
// GetMaxSessions 获取最大并发会话数
// 返回 0 表示未启用
func
(
a
*
Account
)
GetMaxSessions
()
int
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"max_sessions"
];
ok
{
return
parseExtraInt
(
v
)
}
return
0
}
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
// 默认值为 5 分钟
func
(
a
*
Account
)
GetSessionIdleTimeoutMinutes
()
int
{
if
a
.
Extra
==
nil
{
return
5
}
if
v
,
ok
:=
a
.
Extra
[
"session_idle_timeout_minutes"
];
ok
{
val
:=
parseExtraInt
(
v
)
if
val
>
0
{
return
val
}
}
return
5
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
func
(
a
*
Account
)
CheckWindowCostSchedulability
(
currentWindowCost
float64
)
WindowCostSchedulability
{
limit
:=
a
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
WindowCostSchedulable
}
if
currentWindowCost
<
limit
{
return
WindowCostSchedulable
}
stickyReserve
:=
a
.
GetWindowCostStickyReserve
()
if
currentWindowCost
<
limit
+
stickyReserve
{
return
WindowCostStickyOnly
}
return
WindowCostNotSchedulable
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func
parseExtraFloat64
(
value
any
)
float64
{
switch
v
:=
value
.
(
type
)
{
case
float64
:
return
v
case
float32
:
return
float64
(
v
)
case
int
:
return
float64
(
v
)
case
int64
:
return
float64
(
v
)
case
json
.
Number
:
if
f
,
err
:=
v
.
Float64
();
err
==
nil
{
return
f
}
case
string
:
if
f
,
err
:=
strconv
.
ParseFloat
(
strings
.
TrimSpace
(
v
),
64
);
err
==
nil
{
return
f
}
}
return
0
}
// parseExtraInt 从 extra 字段解析 int 值
func
parseExtraInt
(
value
any
)
int
{
switch
v
:=
value
.
(
type
)
{
case
int
:
return
v
case
int64
:
return
int
(
v
)
case
float64
:
return
int
(
v
)
case
json
.
Number
:
if
i
,
err
:=
v
.
Int64
();
err
==
nil
{
return
int
(
i
)
}
case
string
:
if
i
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
v
));
err
==
nil
{
return
i
}
}
return
0
}
backend/internal/service/account_usage_service.go
View file @
bdc426a7
...
...
@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
// 用于账号列表页面显示当前窗口费用
func
(
s
*
AccountUsageService
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
accountID
,
startTime
)
}
backend/internal/service/antigravity_model_mapping_test.go
View file @
bdc426a7
...
...
@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{
"可映射 - claude-3-haiku-20240307"
,
"claude-3-haiku-20240307"
,
true
},
// Gemini 前缀透传
{
"Gemini前缀 - gemini-
1
.5-pro"
,
"gemini-
1
.5-pro"
,
true
},
{
"Gemini前缀 - gemini-
2
.5-pro"
,
"gemini-
2
.5-pro"
,
true
},
{
"Gemini前缀 - gemini-unknown-model"
,
"gemini-unknown-model"
,
true
},
{
"Gemini前缀 - gemini-future-version"
,
"gemini-future-version"
,
true
},
...
...
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected
:
"gemini-2.5-flash"
,
},
{
name
:
"Gemini透传 - gemini-
1
.5-pro"
,
requestedModel
:
"gemini-
1
.5-pro"
,
name
:
"Gemini透传 - gemini-
2
.5-pro"
,
requestedModel
:
"gemini-
2
.5-pro"
,
accountMapping
:
nil
,
expected
:
"gemini-
1
.5-pro"
,
expected
:
"gemini-
2
.5-pro"
,
},
{
name
:
"Gemini透传 - gemini-future-model"
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment