Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7331220e
Commit
7331220e
authored
Jan 01, 2026
by
Edric Li
Browse files
Merge remote-tracking branch 'upstream/main'
# Conflicts: # frontend/src/components/account/CreateAccountModal.vue
parents
fb86002e
4f13c8de
Changes
215
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/account.go
View file @
7331220e
...
...
@@ -3,6 +3,7 @@ package service
import
(
"encoding/json"
"strconv"
"strings"
"time"
)
...
...
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return
a
.
Platform
==
PlatformGemini
}
func
(
a
*
Account
)
GeminiOAuthType
()
string
{
if
a
.
Platform
!=
PlatformGemini
||
a
.
Type
!=
AccountTypeOAuth
{
return
""
}
oauthType
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"oauth_type"
))
if
oauthType
==
""
&&
strings
.
TrimSpace
(
a
.
GetCredential
(
"project_id"
))
!=
""
{
return
"code_assist"
}
return
oauthType
}
func
(
a
*
Account
)
GeminiTierID
()
string
{
tierID
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"tier_id"
))
if
tierID
==
""
{
return
""
}
return
strings
.
ToUpper
(
tierID
)
}
func
(
a
*
Account
)
IsGeminiCodeAssist
()
bool
{
if
a
.
Platform
!=
PlatformGemini
||
a
.
Type
!=
AccountTypeOAuth
{
return
false
}
oauthType
:=
a
.
GeminiOAuthType
()
if
oauthType
==
""
{
return
strings
.
TrimSpace
(
a
.
GetCredential
(
"project_id"
))
!=
""
}
return
oauthType
==
"code_assist"
}
func
(
a
*
Account
)
CanGetUsage
()
bool
{
return
a
.
Type
==
AccountTypeOAuth
}
...
...
@@ -110,6 +141,28 @@ func (a *Account) GetCredential(key string) string {
}
}
// GetCredentialAsTime 解析凭证中的时间戳字段,支持多种格式
// 兼容以下格式:
// - RFC3339 字符串: "2025-01-01T00:00:00Z"
// - Unix 时间戳字符串: "1735689600"
// - Unix 时间戳数字: 1735689600 (float64/int64/json.Number)
func
(
a
*
Account
)
GetCredentialAsTime
(
key
string
)
*
time
.
Time
{
s
:=
a
.
GetCredential
(
key
)
if
s
==
""
{
return
nil
}
// 尝试 RFC3339 格式
if
t
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
s
);
err
==
nil
{
return
&
t
}
// 尝试 Unix 时间戳(纯数字字符串)
if
ts
,
err
:=
strconv
.
ParseInt
(
s
,
10
,
64
);
err
==
nil
{
t
:=
time
.
Unix
(
ts
,
0
)
return
&
t
}
return
nil
}
func
(
a
*
Account
)
GetModelMapping
()
map
[
string
]
string
{
if
a
.
Credentials
==
nil
{
return
nil
...
...
@@ -324,19 +377,7 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if
!
a
.
IsOpenAIOAuth
()
{
return
nil
}
expiresAtStr
:=
a
.
GetCredential
(
"expires_at"
)
if
expiresAtStr
==
""
{
return
nil
}
t
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
expiresAtStr
)
if
err
!=
nil
{
if
v
,
ok
:=
a
.
Credentials
[
"expires_at"
]
.
(
float64
);
ok
{
tt
:=
time
.
Unix
(
int64
(
v
),
0
)
return
&
tt
}
return
nil
}
return
&
t
return
a
.
GetCredentialAsTime
(
"expires_at"
)
}
func
(
a
*
Account
)
IsOpenAITokenExpired
()
bool
{
...
...
backend/internal/service/account_service.go
View file @
7331220e
...
...
@@ -5,12 +5,13 @@ import (
"fmt"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
var
(
ErrAccountNotFound
=
infraerrors
.
NotFound
(
"ACCOUNT_NOT_FOUND"
,
"account not found"
)
ErrAccountNilInput
=
infraerrors
.
BadRequest
(
"ACCOUNT_NIL_INPUT"
,
"account input cannot be nil"
)
)
type
AccountRepository
interface
{
...
...
backend/internal/service/account_test_service.go
View file @
7331220e
...
...
@@ -12,7 +12,6 @@ import (
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
...
...
@@ -187,9 +186,8 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Check if token needs refresh
needRefresh
:=
false
if
expiresAtStr
:=
account
.
GetCredential
(
"expires_at"
);
expiresAtStr
!=
""
{
expiresAt
,
err
:=
strconv
.
ParseInt
(
expiresAtStr
,
10
,
64
)
if
err
==
nil
&&
time
.
Now
()
.
Unix
()
+
300
>
expiresAt
{
if
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
);
expiresAt
!=
nil
{
if
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
.
After
(
*
expiresAt
)
{
needRefresh
=
true
}
}
...
...
@@ -263,7 +261,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
)
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
@@ -378,7 +376,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
)
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
@@ -449,7 +447,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
)
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
...
...
backend/internal/service/account_usage_service.go
View file @
7331220e
...
...
@@ -52,6 +52,9 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
GetApiKeyStatsAggregated
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
GetAccountStatsAggregated
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
GetModelStatsAggregated
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
GetDailyStatsAggregated
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
map
[
string
]
any
,
error
)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
...
...
@@ -90,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息
type
UsageInfo
struct
{
UpdatedAt
*
time
.
Time
`json:"updated_at,omitempty"`
// 更新时间
FiveHour
*
UsageProgress
`json:"five_hour"`
// 5小时窗口
SevenDay
*
UsageProgress
`json:"seven_day,omitempty"`
// 7天窗口
SevenDaySonnet
*
UsageProgress
`json:"seven_day_sonnet,omitempty"`
// 7天Sonnet窗口
UpdatedAt
*
time
.
Time
`json:"updated_at,omitempty"`
// 更新时间
FiveHour
*
UsageProgress
`json:"five_hour"`
// 5小时窗口
SevenDay
*
UsageProgress
`json:"seven_day,omitempty"`
// 7天窗口
SevenDaySonnet
*
UsageProgress
`json:"seven_day_sonnet,omitempty"`
// 7天Sonnet窗口
GeminiProDaily
*
UsageProgress
`json:"gemini_pro_daily,omitempty"`
// Gemini Pro 日配额
GeminiFlashDaily
*
UsageProgress
`json:"gemini_flash_daily,omitempty"`
// Gemini Flash 日配额
}
// ClaudeUsageResponse Anthropic API返回的usage结构
...
...
@@ -119,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务
type
AccountUsageService
struct
{
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
usageFetcher
ClaudeUsageFetcher
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
usageFetcher
ClaudeUsageFetcher
geminiQuotaService
*
GeminiQuotaService
}
// NewAccountUsageService 创建AccountUsageService实例
func
NewAccountUsageService
(
accountRepo
AccountRepository
,
usageLogRepo
UsageLogRepository
,
usageFetcher
ClaudeUsageFetcher
)
*
AccountUsageService
{
func
NewAccountUsageService
(
accountRepo
AccountRepository
,
usageLogRepo
UsageLogRepository
,
usageFetcher
ClaudeUsageFetcher
,
geminiQuotaService
*
GeminiQuotaService
)
*
AccountUsageService
{
return
&
AccountUsageService
{
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
usageFetcher
:
usageFetcher
,
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
usageFetcher
:
usageFetcher
,
geminiQuotaService
:
geminiQuotaService
,
}
}
...
...
@@ -143,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return
nil
,
fmt
.
Errorf
(
"get account failed: %w"
,
err
)
}
if
account
.
Platform
==
PlatformGemini
{
return
s
.
getGeminiUsage
(
ctx
,
account
)
}
// 只有oauth类型账号可以通过API获取usage(有profile scope)
if
account
.
CanGetUsage
()
{
var
apiResp
*
ClaudeUsageResponse
...
...
@@ -189,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return
nil
,
fmt
.
Errorf
(
"account type %s does not support usage query"
,
account
.
Type
)
}
func
(
s
*
AccountUsageService
)
getGeminiUsage
(
ctx
context
.
Context
,
account
*
Account
)
(
*
UsageInfo
,
error
)
{
now
:=
time
.
Now
()
usage
:=
&
UsageInfo
{
UpdatedAt
:
&
now
,
}
if
s
.
geminiQuotaService
==
nil
||
s
.
usageLogRepo
==
nil
{
return
usage
,
nil
}
quota
,
ok
:=
s
.
geminiQuotaService
.
QuotaForAccount
(
ctx
,
account
)
if
!
ok
{
return
usage
,
nil
}
start
:=
geminiDailyWindowStart
(
now
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini usage stats failed: %w"
,
err
)
}
totals
:=
geminiAggregateUsage
(
stats
)
resetAt
:=
geminiDailyResetTime
(
now
)
usage
.
GeminiProDaily
=
buildGeminiUsageProgress
(
totals
.
ProRequests
,
quota
.
ProRPD
,
resetAt
,
totals
.
ProTokens
,
totals
.
ProCost
,
now
)
usage
.
GeminiFlashDaily
=
buildGeminiUsageProgress
(
totals
.
FlashRequests
,
quota
.
FlashRPD
,
resetAt
,
totals
.
FlashTokens
,
totals
.
FlashCost
,
now
)
return
usage
,
nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离
func
(
s
*
AccountUsageService
)
addWindowStats
(
ctx
context
.
Context
,
account
*
Account
,
usage
*
UsageInfo
)
{
...
...
@@ -385,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据
return
info
}
func
buildGeminiUsageProgress
(
used
,
limit
int64
,
resetAt
time
.
Time
,
tokens
int64
,
cost
float64
,
now
time
.
Time
)
*
UsageProgress
{
if
limit
<=
0
{
return
nil
}
utilization
:=
(
float64
(
used
)
/
float64
(
limit
))
*
100
remainingSeconds
:=
int
(
resetAt
.
Sub
(
now
)
.
Seconds
())
if
remainingSeconds
<
0
{
remainingSeconds
=
0
}
resetCopy
:=
resetAt
return
&
UsageProgress
{
Utilization
:
utilization
,
ResetsAt
:
&
resetCopy
,
RemainingSeconds
:
remainingSeconds
,
WindowStats
:
&
WindowStats
{
Requests
:
used
,
Tokens
:
tokens
,
Cost
:
cost
,
},
}
}
backend/internal/service/admin_service.go
View file @
7331220e
...
...
@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType
=
SubscriptionTypeStandard
}
// 限额字段:0 和 nil 都表示"无限制"
dailyLimit
:=
normalizeLimit
(
input
.
DailyLimitUSD
)
weeklyLimit
:=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
monthlyLimit
:=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
group
:=
&
Group
{
Name
:
input
.
Name
,
Description
:
input
.
Description
,
...
...
@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive
:
input
.
IsExclusive
,
Status
:
StatusActive
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
input
.
D
ailyLimit
USD
,
WeeklyLimitUSD
:
input
.
W
eeklyLimit
USD
,
MonthlyLimitUSD
:
input
.
M
onthlyLimit
USD
,
DailyLimitUSD
:
d
ailyLimit
,
WeeklyLimitUSD
:
w
eeklyLimit
,
MonthlyLimitUSD
:
m
onthlyLimit
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return
group
,
nil
}
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
func
normalizeLimit
(
limit
*
float64
)
*
float64
{
if
limit
==
nil
||
*
limit
<=
0
{
return
nil
}
return
limit
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
...
...
@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
SubscriptionType
!=
""
{
group
.
SubscriptionType
=
input
.
SubscriptionType
}
// 限额字段
支持设置为nil(清除限额)或具体值
// 限额字段
:0 和 nil 都表示"无限制",正数表示具体限额
if
input
.
DailyLimitUSD
!=
nil
{
group
.
DailyLimitUSD
=
input
.
DailyLimitUSD
group
.
DailyLimitUSD
=
normalizeLimit
(
input
.
DailyLimitUSD
)
}
if
input
.
WeeklyLimitUSD
!=
nil
{
group
.
WeeklyLimitUSD
=
input
.
WeeklyLimitUSD
group
.
WeeklyLimitUSD
=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
}
if
input
.
MonthlyLimitUSD
!=
nil
{
group
.
MonthlyLimitUSD
=
input
.
MonthlyLimitUSD
group
.
MonthlyLimitUSD
=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
7331220e
...
...
@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
// Antigravity 直接支持的模型
// Antigravity 直接支持的模型
(精确匹配透传)
var
antigravitySupportedModels
=
map
[
string
]
bool
{
"claude-opus-4-5-thinking"
:
true
,
"claude-sonnet-4-5"
:
true
,
...
...
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-high"
:
true
,
"gemini-3-pro-preview"
:
true
,
"gemini-3-pro-image"
:
true
,
}
// Antigravity 系统默认模型映射表(不支持 → 支持)
var
antigravityModelMapping
=
map
[
string
]
string
{
"claude-3-5-sonnet-20241022"
:
"claude-sonnet-4-5"
,
"claude-3-5-sonnet-20240620"
:
"claude-sonnet-4-5"
,
"claude-sonnet-4-5-20250929"
:
"claude-sonnet-4-5-thinking"
,
"claude-opus-4"
:
"claude-opus-4-5-thinking"
,
"claude-opus-4-5-20251101"
:
"claude-opus-4-5-thinking"
,
"claude-haiku-4"
:
"gemini-3-flash"
,
"claude-haiku-4-5"
:
"gemini-3-flash"
,
"claude-3-haiku-20240307"
:
"gemini-3-flash"
,
"claude-haiku-4-5-20251001"
:
"gemini-3-flash"
,
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview"
:
"gemini-3-pro-image"
,
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
var
antigravityPrefixMapping
=
[]
struct
{
prefix
string
target
string
}{
// 长前缀优先
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"gemini-3-flash"
},
// claude-haiku-4-5-xxx
{
"claude-opus-4-5"
,
"claude-opus-4-5-thinking"
},
{
"claude-3-haiku"
,
"gemini-3-flash"
},
// 旧版 claude-3-haiku-xxx
{
"claude-sonnet-4"
,
"claude-sonnet-4-5"
},
{
"claude-haiku-4"
,
"gemini-3-flash"
},
{
"claude-opus-4"
,
"claude-opus-4-5-thinking"
},
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
...
...
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
}
// getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func
(
s
*
AntigravityGatewayService
)
getMappedModel
(
account
*
Account
,
requestedModel
string
)
string
{
// 1.
优先使用
账户级映射(
复用现有方法
)
// 1. 账户级映射(
用户自定义优先
)
if
mapped
:=
account
.
GetMappedModel
(
requestedModel
);
mapped
!=
requestedModel
{
return
mapped
}
// 2.
系统默认映射
if
mapped
,
ok
:=
antigravityModelMapping
[
requestedModel
]
;
ok
{
return
mapped
// 2.
直接支持的模型透传
if
antigravitySupportedModels
[
requestedModel
]
{
return
requestedModel
}
// 3. Gemini 模型透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
requestedModel
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
for
_
,
pm
:=
range
antigravityPrefixMapping
{
if
strings
.
HasPrefix
(
requestedModel
,
pm
.
prefix
)
{
return
pm
.
target
}
}
// 4.
Claude 前缀透传直接支持的
模型
if
antigravitySupportedModels
[
requestedModel
]
{
// 4.
Gemini 模型透传(未匹配到前缀的 gemini
模型
)
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
requestedModel
}
...
...
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
}
// IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func
(
s
*
AntigravityGatewayService
)
IsModelSupported
(
requestedModel
string
)
bool
{
// 直接支持的模型
if
antigravitySupportedModels
[
requestedModel
]
{
return
true
}
// 可映射的模型
if
_
,
ok
:=
antigravityModelMapping
[
requestedModel
];
ok
{
return
true
}
// Gemini 前缀透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
true
}
// Claude 模型支持(通过默认映射)
if
strings
.
HasPrefix
(
requestedModel
,
"claude-"
)
{
return
true
}
return
false
return
strings
.
HasPrefix
(
requestedModel
,
"claude-"
)
||
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
}
// TestConnectionResult 测试连接结果
...
...
@@ -180,7 +172,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
)
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"请求失败: %w"
,
err
)
}
...
...
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if
bodyJSON
,
err
:=
json
.
Marshal
(
geminiBody
);
err
==
nil
{
truncated
:=
string
(
bodyJSON
)
if
len
(
truncated
)
>
2000
{
truncated
=
truncated
[
:
2000
]
+
"..."
}
log
.
Printf
(
"[Debug] Transformed Gemini request: %s"
,
truncated
)
}
// 构建上游 action
action
:=
"generateContent"
if
claudeReq
.
Stream
{
...
...
@@ -372,7 +373,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
nil
,
err
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
)
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"Antigravity account %d: upstream request failed, retry %d/%d: %v"
,
account
.
ID
,
attempt
,
antigravityMaxRetries
,
err
)
...
...
@@ -515,7 +516,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return
nil
,
err
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
)
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"Antigravity account %d: upstream request failed, retry %d/%d: %v"
,
account
.
ID
,
attempt
,
antigravityMaxRetries
,
err
)
...
...
backend/internal/service/antigravity_model_mapping_test.go
View file @
7331220e
...
...
@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name
:
"系统映射 - claude-sonnet-4-5-20250929"
,
requestedModel
:
"claude-sonnet-4-5-20250929"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5
-thinking
"
,
expected
:
"claude-sonnet-4-5"
,
},
// 3. Gemini 透传
...
...
backend/internal/service/antigravity_quota_refresher.go
View file @
7331220e
...
...
@@ -191,7 +191,7 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
// isTokenExpired 检查 token 是否过期
func
(
r
*
AntigravityQuotaRefresher
)
isTokenExpired
(
account
*
Account
)
bool
{
expiresAt
:=
parseAntigravityExpiresAt
(
account
)
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
{
return
false
}
...
...
backend/internal/service/antigravity_token_provider.go
View file @
7331220e
...
...
@@ -55,7 +55,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
}
// 2. 如果即将过期则刷新
expiresAt
:=
parseAntigravityExpiresAt
(
account
)
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
...
...
@@ -72,7 +72,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
parseAntigravityExpiresAt
(
account
)
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
{
if
p
.
antigravityOAuthService
==
nil
{
return
""
,
errors
.
New
(
"antigravity oauth service not configured"
)
...
...
@@ -91,7 +91,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
log
.
Printf
(
"[AntigravityTokenProvider] Failed to update account credentials: %v"
,
updateErr
)
}
expiresAt
=
parseAntigravityExpiresAt
(
account
)
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
...
...
@@ -128,18 +128,3 @@ func antigravityTokenCacheKey(account *Account) string {
}
return
"ag:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
func
parseAntigravityExpiresAt
(
account
*
Account
)
*
time
.
Time
{
raw
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"expires_at"
))
if
raw
==
""
{
return
nil
}
if
unixSec
,
err
:=
strconv
.
ParseInt
(
raw
,
10
,
64
);
err
==
nil
&&
unixSec
>
0
{
t
:=
time
.
Unix
(
unixSec
,
0
)
return
&
t
}
if
t
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
raw
);
err
==
nil
{
return
&
t
}
return
nil
}
backend/internal/service/antigravity_token_refresher.go
View file @
7331220e
...
...
@@ -2,7 +2,7 @@ package service
import
(
"context"
"
strconv
"
"
fmt
"
"time"
)
...
...
@@ -29,21 +29,22 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
}
// NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的1
0
分钟刷新窗口,忽略全局配置
// Antigravity 使用固定的1
5
分钟刷新窗口,忽略全局配置
func
(
r
*
AntigravityTokenRefresher
)
NeedsRefresh
(
account
*
Account
,
_
time
.
Duration
)
bool
{
if
!
r
.
CanRefresh
(
account
)
{
return
false
}
expiresAt
Str
:=
account
.
GetCredential
(
"expires_at"
)
if
expiresAt
Str
==
""
{
expiresAt
:=
account
.
GetCredential
AsTime
(
"expires_at"
)
if
expiresAt
==
nil
{
return
false
}
expiresAt
,
err
:=
strconv
.
ParseInt
(
expiresAtStr
,
10
,
64
)
if
err
!=
nil
{
return
false
timeUntilExpiry
:=
time
.
Until
(
*
expiresAt
)
needsRefresh
:=
timeUntilExpiry
<
antigravityRefreshWindow
if
needsRefresh
{
fmt
.
Printf
(
"[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v
\n
"
,
account
.
ID
,
expiresAt
.
Format
(
"2006-01-02 15:04:05"
),
timeUntilExpiry
,
antigravityRefreshWindow
)
}
expiryTime
:=
time
.
Unix
(
expiresAt
,
0
)
return
time
.
Until
(
expiryTime
)
<
antigravityRefreshWindow
return
needsRefresh
}
// Refresh 执行 token 刷新
...
...
backend/internal/service/api_key_service.go
View file @
7331220e
...
...
@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
...
...
backend/internal/service/auth_service.go
View file @
7331220e
...
...
@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
...
...
backend/internal/service/billing_cache_service.go
View file @
7331220e
...
...
@@ -4,10 +4,12 @@ import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
)
// 错误定义
...
...
@@ -27,6 +29,46 @@ type subscriptionCacheData struct {
Version
int64
}
// 缓存写入任务类型
type
cacheWriteKind
int
const
(
cacheWriteSetBalance
cacheWriteKind
=
iota
cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance
)
// 异步缓存写入工作池配置
//
// 性能优化说明:
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
// 3. goroutine 创建/销毁带来额外开销
//
// 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
// 4. 统一超时控制,避免慢操作阻塞工作池
const
(
cacheWriteWorkerCount
=
10
// 工作协程数量
cacheWriteBufferSize
=
1000
// 任务队列缓冲大小
cacheWriteTimeout
=
2
*
time
.
Second
// 单个写入操作超时
cacheWriteDropLogInterval
=
5
*
time
.
Second
// 丢弃日志节流间隔
)
// cacheWriteTask 缓存写入任务
type
cacheWriteTask
struct
{
kind
cacheWriteKind
userID
int64
groupID
int64
balance
float64
amount
float64
subscriptionData
*
subscriptionCacheData
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type
BillingCacheService
struct
{
...
...
@@ -34,16 +76,151 @@ type BillingCacheService struct {
userRepo
UserRepository
subRepo
UserSubscriptionRepository
cfg
*
config
.
Config
cacheWriteChan
chan
cacheWriteTask
cacheWriteWg
sync
.
WaitGroup
cacheWriteStopOnce
sync
.
Once
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount
uint64
cacheWriteDropFullLastLog
int64
cacheWriteDropClosedCount
uint64
cacheWriteDropClosedLastLog
int64
}
// NewBillingCacheService 创建计费缓存服务
func
NewBillingCacheService
(
cache
BillingCache
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
,
cfg
*
config
.
Config
)
*
BillingCacheService
{
return
&
BillingCacheService
{
svc
:=
&
BillingCacheService
{
cache
:
cache
,
userRepo
:
userRepo
,
subRepo
:
subRepo
,
cfg
:
cfg
,
}
svc
.
startCacheWriteWorkers
()
return
svc
}
// Stop 关闭缓存写入工作池
func
(
s
*
BillingCacheService
)
Stop
()
{
s
.
cacheWriteStopOnce
.
Do
(
func
()
{
if
s
.
cacheWriteChan
==
nil
{
return
}
close
(
s
.
cacheWriteChan
)
s
.
cacheWriteWg
.
Wait
()
s
.
cacheWriteChan
=
nil
})
}
func
(
s
*
BillingCacheService
)
startCacheWriteWorkers
()
{
s
.
cacheWriteChan
=
make
(
chan
cacheWriteTask
,
cacheWriteBufferSize
)
for
i
:=
0
;
i
<
cacheWriteWorkerCount
;
i
++
{
s
.
cacheWriteWg
.
Add
(
1
)
go
s
.
cacheWriteWorker
()
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
func
(
s
*
BillingCacheService
)
enqueueCacheWrite
(
task
cacheWriteTask
)
(
enqueued
bool
)
{
if
s
.
cacheWriteChan
==
nil
{
return
false
}
defer
func
()
{
if
recovered
:=
recover
();
recovered
!=
nil
{
// 队列已关闭时可能触发 panic,记录后静默失败。
s
.
logCacheWriteDrop
(
task
,
"closed"
)
enqueued
=
false
}
}()
select
{
case
s
.
cacheWriteChan
<-
task
:
return
true
default
:
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
s
.
logCacheWriteDrop
(
task
,
"full"
)
return
false
}
}
func
(
s
*
BillingCacheService
)
cacheWriteWorker
()
{
defer
s
.
cacheWriteWg
.
Done
()
for
task
:=
range
s
.
cacheWriteChan
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
switch
task
.
kind
{
case
cacheWriteSetBalance
:
s
.
setBalanceCache
(
ctx
,
task
.
userID
,
task
.
balance
)
case
cacheWriteSetSubscription
:
s
.
setSubscriptionCache
(
ctx
,
task
.
userID
,
task
.
groupID
,
task
.
subscriptionData
)
case
cacheWriteUpdateSubscriptionUsage
:
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
UpdateSubscriptionUsage
(
ctx
,
task
.
userID
,
task
.
groupID
,
task
.
amount
);
err
!=
nil
{
log
.
Printf
(
"Warning: update subscription cache failed for user %d group %d: %v"
,
task
.
userID
,
task
.
groupID
,
err
)
}
}
case
cacheWriteDeductBalance
:
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
DeductUserBalance
(
ctx
,
task
.
userID
,
task
.
amount
);
err
!=
nil
{
log
.
Printf
(
"Warning: deduct balance cache failed for user %d: %v"
,
task
.
userID
,
err
)
}
}
}
cancel
()
}
}
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
func
cacheWriteKindName
(
kind
cacheWriteKind
)
string
{
switch
kind
{
case
cacheWriteSetBalance
:
return
"set_balance"
case
cacheWriteSetSubscription
:
return
"set_subscription"
case
cacheWriteUpdateSubscriptionUsage
:
return
"update_subscription_usage"
case
cacheWriteDeductBalance
:
return
"deduct_balance"
default
:
return
"unknown"
}
}
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
func
(
s
*
BillingCacheService
)
logCacheWriteDrop
(
task
cacheWriteTask
,
reason
string
)
{
var
(
countPtr
*
uint64
lastPtr
*
int64
)
switch
reason
{
case
"full"
:
countPtr
=
&
s
.
cacheWriteDropFullCount
lastPtr
=
&
s
.
cacheWriteDropFullLastLog
case
"closed"
:
countPtr
=
&
s
.
cacheWriteDropClosedCount
lastPtr
=
&
s
.
cacheWriteDropClosedLastLog
default
:
return
}
atomic
.
AddUint64
(
countPtr
,
1
)
now
:=
time
.
Now
()
.
UnixNano
()
last
:=
atomic
.
LoadInt64
(
lastPtr
)
if
now
-
last
<
int64
(
cacheWriteDropLogInterval
)
{
return
}
if
!
atomic
.
CompareAndSwapInt64
(
lastPtr
,
last
,
now
)
{
return
}
dropped
:=
atomic
.
SwapUint64
(
countPtr
,
0
)
if
dropped
==
0
{
return
}
log
.
Printf
(
"Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)"
,
reason
,
dropped
,
cacheWriteDropLogInterval
,
cacheWriteKindName
(
task
.
kind
),
task
.
userID
,
task
.
groupID
,
)
}
// ============================================
...
...
@@ -70,11 +247,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
}
// 异步建立缓存
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
s
.
setBalanceCache
(
cacheCtx
,
userID
,
balance
)
}
(
)
_
=
s
.
enqueueCacheWrite
(
cacheWriteTask
{
kind
:
cacheWriteSetBalance
,
userID
:
userID
,
balance
:
balance
,
})
return
balance
,
nil
}
...
...
@@ -98,7 +275,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
}
}
// DeductBalanceCache 扣减余额缓存(
异
步调用
,用于扣费后更新缓存
)
// DeductBalanceCache 扣减余额缓存(
同
步调用)
func
(
s
*
BillingCacheService
)
DeductBalanceCache
(
ctx
context
.
Context
,
userID
int64
,
amount
float64
)
error
{
if
s
.
cache
==
nil
{
return
nil
...
...
@@ -106,6 +283,26 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
return
s
.
cache
.
DeductUserBalance
(
ctx
,
userID
,
amount
)
}
// QueueDeductBalance 异步扣减余额缓存
func
(
s
*
BillingCacheService
)
QueueDeductBalance
(
userID
int64
,
amount
float64
)
{
if
s
.
cache
==
nil
{
return
}
// 队列满时同步回退,避免关键扣减被静默丢弃。
if
s
.
enqueueCacheWrite
(
cacheWriteTask
{
kind
:
cacheWriteDeductBalance
,
userID
:
userID
,
amount
:
amount
,
})
{
return
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
defer
cancel
()
if
err
:=
s
.
DeductBalanceCache
(
ctx
,
userID
,
amount
);
err
!=
nil
{
log
.
Printf
(
"Warning: deduct balance cache fallback failed for user %d: %v"
,
userID
,
err
)
}
}
// InvalidateUserBalance 失效用户余额缓存
func
(
s
*
BillingCacheService
)
InvalidateUserBalance
(
ctx
context
.
Context
,
userID
int64
)
error
{
if
s
.
cache
==
nil
{
...
...
@@ -141,11 +338,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
}
// 异步建立缓存
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
s
.
setSubscriptionCache
(
cacheCtx
,
userID
,
groupID
,
data
)
}()
_
=
s
.
enqueueCacheWrite
(
cacheWriteTask
{
kind
:
cacheWriteSetSubscription
,
userID
:
userID
,
groupID
:
groupID
,
subscriptionData
:
data
,
})
return
data
,
nil
}
...
...
@@ -199,7 +397,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(
异
步调用
,用于扣费后更新缓存
)
// UpdateSubscriptionUsage 更新订阅用量缓存(
同
步调用)
func
(
s
*
BillingCacheService
)
UpdateSubscriptionUsage
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
costUSD
float64
)
error
{
if
s
.
cache
==
nil
{
return
nil
...
...
@@ -207,6 +405,27 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
return
s
.
cache
.
UpdateSubscriptionUsage
(
ctx
,
userID
,
groupID
,
costUSD
)
}
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func
(
s
*
BillingCacheService
)
QueueUpdateSubscriptionUsage
(
userID
,
groupID
int64
,
costUSD
float64
)
{
if
s
.
cache
==
nil
{
return
}
// 队列满时同步回退,确保订阅用量及时更新。
if
s
.
enqueueCacheWrite
(
cacheWriteTask
{
kind
:
cacheWriteUpdateSubscriptionUsage
,
userID
:
userID
,
groupID
:
groupID
,
amount
:
costUSD
,
})
{
return
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
defer
cancel
()
if
err
:=
s
.
UpdateSubscriptionUsage
(
ctx
,
userID
,
groupID
,
costUSD
);
err
!=
nil
{
log
.
Printf
(
"Warning: update subscription cache fallback failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
}
}
// InvalidateSubscription 失效指定订阅缓存
func
(
s
*
BillingCacheService
)
InvalidateSubscription
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
error
{
if
s
.
cache
==
nil
{
...
...
backend/internal/service/billing_cache_service_test.go
0 → 100644
View file @
7331220e
package
service
import
(
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type
billingCacheWorkerStub
struct
{
balanceUpdates
int64
subscriptionUpdates
int64
}
func
(
b
*
billingCacheWorkerStub
)
GetUserBalance
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
b
*
billingCacheWorkerStub
)
SetUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
)
error
{
atomic
.
AddInt64
(
&
b
.
balanceUpdates
,
1
)
return
nil
}
func
(
b
*
billingCacheWorkerStub
)
DeductUserBalance
(
ctx
context
.
Context
,
userID
int64
,
amount
float64
)
error
{
atomic
.
AddInt64
(
&
b
.
balanceUpdates
,
1
)
return
nil
}
func
(
b
*
billingCacheWorkerStub
)
InvalidateUserBalance
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
b
*
billingCacheWorkerStub
)
GetSubscriptionCache
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
SubscriptionCacheData
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
b
*
billingCacheWorkerStub
)
SetSubscriptionCache
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
data
*
SubscriptionCacheData
)
error
{
atomic
.
AddInt64
(
&
b
.
subscriptionUpdates
,
1
)
return
nil
}
func
(
b
*
billingCacheWorkerStub
)
UpdateSubscriptionUsage
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
cost
float64
)
error
{
atomic
.
AddInt64
(
&
b
.
subscriptionUpdates
,
1
)
return
nil
}
func
(
b
*
billingCacheWorkerStub
)
InvalidateSubscriptionCache
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
error
{
return
nil
}
func
TestBillingCacheServiceQueueHighLoad
(
t
*
testing
.
T
)
{
cache
:=
&
billingCacheWorkerStub
{}
svc
:=
NewBillingCacheService
(
cache
,
nil
,
nil
,
&
config
.
Config
{})
t
.
Cleanup
(
svc
.
Stop
)
start
:=
time
.
Now
()
for
i
:=
0
;
i
<
cacheWriteBufferSize
*
2
;
i
++
{
svc
.
QueueDeductBalance
(
1
,
1
)
}
require
.
Less
(
t
,
time
.
Since
(
start
),
2
*
time
.
Second
)
svc
.
QueueUpdateSubscriptionUsage
(
1
,
2
,
1.5
)
require
.
Eventually
(
t
,
func
()
bool
{
return
atomic
.
LoadInt64
(
&
cache
.
balanceUpdates
)
>
0
},
2
*
time
.
Second
,
10
*
time
.
Millisecond
)
require
.
Eventually
(
t
,
func
()
bool
{
return
atomic
.
LoadInt64
(
&
cache
.
subscriptionUpdates
)
>
0
},
2
*
time
.
Second
,
10
*
time
.
Millisecond
)
}
backend/internal/service/concurrency_service.go
View file @
7331220e
...
...
@@ -9,24 +9,35 @@ import (
"time"
)
// ConcurrencyCache
defines cache operations for concurrency service
//
Uses independent keys per request slot with native Redis TTL for automatic cleanup
// ConcurrencyCache
定义并发控制的缓存接口
//
使用有序集合存储槽位,按时间戳清理过期条目
type
ConcurrencyCache
interface
{
//
Account slot management - each slot is a separate key with independent TTL
//
Key format
: concurrency:account:{accountID}
:{
requestID
}
//
账号槽位管理
//
键格式
: concurrency:account:{accountID}
(有序集合,成员为
requestID
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
// User slot management - each slot is a separate key with independent TTL
// Key format: concurrency:user:{userID}:{requestID}
// 账号等待队列(账号级)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
ReleaseUserSlot
(
ctx
context
.
Context
,
userID
int64
,
requestID
string
)
error
GetUserConcurrency
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
//
Wait queue - uses counter with TTL set only on creation
//
等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
error
// 批量负载查询(只读)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
...
...
@@ -61,6 +72,18 @@ type AcquireResult struct {
ReleaseFunc
func
()
// Must be called when done (typically via defer)
}
type
AccountWithConcurrency
struct
{
ID
int64
MaxConcurrency
int
}
type
AccountLoadInfo
struct
{
AccountID
int64
CurrentConcurrency
int
WaitingCount
int
LoadRate
int
// 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
...
...
@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func
(
s
*
ConcurrencyService
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
if
s
.
cache
==
nil
{
return
true
,
nil
}
result
,
err
:=
s
.
cache
.
IncrementAccountWaitCount
(
ctx
,
accountID
,
maxWait
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: increment wait count failed for account %d: %v"
,
accountID
,
err
)
return
true
,
nil
}
return
result
,
nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func
(
s
*
ConcurrencyService
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
{
if
s
.
cache
==
nil
{
return
}
bgCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
s
.
cache
.
DecrementAccountWaitCount
(
bgCtx
,
accountID
);
err
!=
nil
{
log
.
Printf
(
"Warning: decrement wait count failed for account %d: %v"
,
accountID
,
err
)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func
(
s
*
ConcurrencyService
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
s
.
cache
==
nil
{
return
0
,
nil
}
return
s
.
cache
.
GetAccountWaitingCount
(
ctx
,
accountID
)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func
CalculateMaxWait
(
userConcurrency
int
)
int
{
...
...
@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
return
userConcurrency
+
defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts.
func
(
s
*
ConcurrencyService
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
s
.
cache
==
nil
{
return
map
[
int64
]
*
AccountLoadInfo
{},
nil
}
return
s
.
cache
.
GetAccountsLoadBatch
(
ctx
,
accounts
)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func
(
s
*
ConcurrencyService
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
if
s
.
cache
==
nil
{
return
nil
}
return
s
.
cache
.
CleanupExpiredAccountSlots
(
ctx
,
accountID
)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func
(
s
*
ConcurrencyService
)
StartSlotCleanupWorker
(
accountRepo
AccountRepository
,
interval
time
.
Duration
)
{
if
s
==
nil
||
s
.
cache
==
nil
||
accountRepo
==
nil
||
interval
<=
0
{
return
}
runCleanup
:=
func
()
{
listCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
accounts
,
err
:=
accountRepo
.
ListSchedulable
(
listCtx
)
cancel
()
if
err
!=
nil
{
log
.
Printf
(
"Warning: list schedulable accounts failed: %v"
,
err
)
return
}
for
_
,
account
:=
range
accounts
{
accountCtx
,
accountCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
err
:=
s
.
cache
.
CleanupExpiredAccountSlots
(
accountCtx
,
account
.
ID
)
accountCancel
()
if
err
!=
nil
{
log
.
Printf
(
"Warning: cleanup expired slots failed for account %d: %v"
,
account
.
ID
,
err
)
}
}
}
go
func
()
{
ticker
:=
time
.
NewTicker
(
interval
)
defer
ticker
.
Stop
()
runCleanup
()
for
range
ticker
.
C
{
runCleanup
()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
...
...
backend/internal/service/crs_sync_service.go
View file @
7331220e
...
...
@@ -12,6 +12,8 @@ import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
type
CRSSyncService
struct
{
...
...
@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return
nil
,
errors
.
New
(
"username and password are required"
)
}
client
:=
&
http
.
Client
{
Timeout
:
20
*
time
.
Second
}
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
20
*
time
.
Second
,
})
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
20
*
time
.
Second
}
}
adminToken
,
err
:=
crsLogin
(
ctx
,
client
,
baseURL
,
input
.
Username
,
input
.
Password
)
if
err
!=
nil
{
...
...
backend/internal/service/domain_constants.go
View file @
7331220e
...
...
@@ -91,6 +91,9 @@ const (
// 管理员 API Key
SettingKeyAdminApiKey
=
"admin_api_key"
// 全局管理员 API Key(用于外部系统集成)
// Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy
=
"gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)
...
...
backend/internal/service/email_service.go
View file @
7331220e
...
...
@@ -10,7 +10,7 @@ import (
"strconv"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
)
var
(
...
...
backend/internal/service/gateway_multiplatform_test.go
View file @
7331220e
...
...
@@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级应选择最久未用的账户"
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeApiKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
,
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且未使用时应优先选择OAuth账户"
)
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func
TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
...
...
@@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func
TestGatewayService_selectAccountWithMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"混合调度-Gemini优先选择OAuth账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeApiKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
,
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且未使用时应优先选择OAuth账户"
)
})
t
.
Run
(
"混合调度-包含启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
...
...
@@ -783,3 +837,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
})
}
}
// mockConcurrencyService for testing
type
mockConcurrencyService
struct
{
accountLoads
map
[
int64
]
*
AccountLoadInfo
accountWaitCounts
map
[
int64
]
int
acquireResults
map
[
int64
]
bool
}
func
(
m
*
mockConcurrencyService
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
m
.
accountLoads
==
nil
{
return
map
[
int64
]
*
AccountLoadInfo
{},
nil
}
result
:=
make
(
map
[
int64
]
*
AccountLoadInfo
)
for
_
,
acc
:=
range
accounts
{
if
load
,
ok
:=
m
.
accountLoads
[
acc
.
ID
];
ok
{
result
[
acc
.
ID
]
=
load
}
else
{
result
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
CurrentConcurrency
:
0
,
WaitingCount
:
0
,
LoadRate
:
0
,
}
}
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyService
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
m
.
accountWaitCounts
==
nil
{
return
0
,
nil
}
return
m
.
accountWaitCounts
[
accountID
],
nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func
TestGatewayService_SelectAccountWithLoadAwareness
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"禁用负载批量查询-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// No concurrency service
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"无ConcurrencyService-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"排除账号-不选择被排除的账号"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"不应选择被排除的账号"
)
})
t
.
Run
(
"无可用账号-返回错误"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
}
backend/internal/service/gateway_request.go
0 → 100644
View file @
7331220e
package
service
import
(
"encoding/json"
"fmt"
)
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次):
// 1. gateway_handler.go 解析获取 model 和 stream
// 2. gateway_service.go 再次解析获取 system、messages、metadata
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
//
// 新实现一次解析,多处复用:
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
type
ParsedRequest
struct
{
Body
[]
byte
// 原始请求体(保留用于转发)
Model
string
// 请求的模型名称
Stream
bool
// 是否为流式请求
MetadataUserID
string
// metadata.user_id(用于会话亲和)
System
any
// system 字段内容
Messages
[]
any
// messages 数组
HasSystem
bool
// 是否包含 system 字段(包含 null 也视为显式传入)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func
ParseGatewayRequest
(
body
[]
byte
)
(
*
ParsedRequest
,
error
)
{
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
nil
,
err
}
parsed
:=
&
ParsedRequest
{
Body
:
body
,
}
if
rawModel
,
exists
:=
req
[
"model"
];
exists
{
model
,
ok
:=
rawModel
.
(
string
)
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"invalid model field type"
)
}
parsed
.
Model
=
model
}
if
rawStream
,
exists
:=
req
[
"stream"
];
exists
{
stream
,
ok
:=
rawStream
.
(
bool
)
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"invalid stream field type"
)
}
parsed
.
Stream
=
stream
}
if
metadata
,
ok
:=
req
[
"metadata"
]
.
(
map
[
string
]
any
);
ok
{
if
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
ok
{
parsed
.
MetadataUserID
=
userID
}
}
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。
if
system
,
ok
:=
req
[
"system"
];
ok
{
parsed
.
HasSystem
=
true
parsed
.
System
=
system
}
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
parsed
.
Messages
=
messages
}
return
parsed
,
nil
}
Prev
1
…
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment