Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0b746501
Commit
0b746501
authored
Apr 16, 2026
by
陈曦
Browse files
1. merge upstream v0.1.113 2.提交migration相关文件
parents
45061102
be7551b9
Changes
225
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/gateway_record_usage_test.go
View file @
0b746501
...
@@ -43,6 +43,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
...
@@ -43,6 +43,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
)
}
}
...
...
backend/internal/service/gateway_request.go
View file @
0b746501
...
@@ -75,6 +75,9 @@ type ParsedRequest struct {
...
@@ -75,6 +75,9 @@ type ParsedRequest struct {
MaxTokens
int
// max_tokens 值(用于探测请求拦截)
MaxTokens
int
// max_tokens 值(用于探测请求拦截)
SessionContext
*
SessionContext
// 可选:请求上下文区分因子(nil 时行为不变)
SessionContext
*
SessionContext
// 可选:请求上下文区分因子(nil 时行为不变)
// GroupID 请求所属分组 ID(来自 API Key)
GroupID
*
int64
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
OnUpstreamAccepted
func
()
OnUpstreamAccepted
func
()
...
...
backend/internal/service/gateway_service.go
View file @
0b746501
...
@@ -503,7 +503,6 @@ type ForwardResult struct {
...
@@ -503,7 +503,6 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用)
// 图片生成计费字段(图片生成模型使用)
ImageCount
int
// 生成的图片数量
ImageCount
int
// 生成的图片数量
ImageSize
string
// 图片尺寸 "1K", "2K", "4K"
ImageSize
string
// 图片尺寸 "1K", "2K", "4K"
}
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
...
@@ -570,6 +569,7 @@ type GatewayService struct {
...
@@ -570,6 +569,7 @@ type GatewayService struct {
resolver
*
ModelPricingResolver
resolver
*
ModelPricingResolver
debugGatewayBodyFile
atomic
.
Pointer
[
os
.
File
]
// non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
debugGatewayBodyFile
atomic
.
Pointer
[
os
.
File
]
// non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService
*
TLSFingerprintProfileService
tlsFPProfileService
*
TLSFingerprintProfileService
balanceNotifyService
*
BalanceNotifyService
}
}
// NewGatewayService creates a new GatewayService
// NewGatewayService creates a new GatewayService
...
@@ -599,6 +599,7 @@ func NewGatewayService(
...
@@ -599,6 +599,7 @@ func NewGatewayService(
tlsFPProfileService
*
TLSFingerprintProfileService
,
tlsFPProfileService
*
TLSFingerprintProfileService
,
channelService
*
ChannelService
,
channelService
*
ChannelService
,
resolver
*
ModelPricingResolver
,
resolver
*
ModelPricingResolver
,
balanceNotifyService
*
BalanceNotifyService
,
)
*
GatewayService
{
)
*
GatewayService
{
userGroupRateTTL
:=
resolveUserGroupRateCacheTTL
(
cfg
)
userGroupRateTTL
:=
resolveUserGroupRateCacheTTL
(
cfg
)
modelsListTTL
:=
resolveModelsListCacheTTL
(
cfg
)
modelsListTTL
:=
resolveModelsListCacheTTL
(
cfg
)
...
@@ -633,6 +634,7 @@ func NewGatewayService(
...
@@ -633,6 +634,7 @@ func NewGatewayService(
tlsFPProfileService
:
tlsFPProfileService
,
tlsFPProfileService
:
tlsFPProfileService
,
channelService
:
channelService
,
channelService
:
channelService
,
resolver
:
resolver
,
resolver
:
resolver
,
balanceNotifyService
:
balanceNotifyService
,
}
}
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
userGroupRateRepo
,
userGroupRateRepo
,
...
@@ -1329,6 +1331,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1329,6 +1331,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ctx
=
s
.
withWindowCostPrefetch
(
ctx
,
accounts
)
ctx
=
s
.
withWindowCostPrefetch
(
ctx
,
accounts
)
ctx
=
s
.
withRPMPrefetch
(
ctx
,
accounts
)
ctx
=
s
.
withRPMPrefetch
(
ctx
,
accounts
)
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID
:=
make
(
map
[
int64
]
*
Account
,
len
(
accounts
))
for
i
:=
range
accounts
{
accountByID
[
accounts
[
i
]
.
ID
]
=
&
accounts
[
i
]
}
isExcluded
:=
func
(
accountID
int64
)
bool
{
isExcluded
:=
func
(
accountID
int64
)
bool
{
if
excludedIDs
==
nil
{
if
excludedIDs
==
nil
{
return
false
return
false
...
@@ -1337,12 +1344,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1337,12 +1344,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
excluded
return
excluded
}
}
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID
:=
make
(
map
[
int64
]
*
Account
,
len
(
accounts
))
for
i
:=
range
accounts
{
accountByID
[
accounts
[
i
]
.
ID
]
=
&
accounts
[
i
]
}
// 获取模型路由配置(仅 anthropic 平台)
// 获取模型路由配置(仅 anthropic 平台)
var
routingAccountIDs
[]
int64
var
routingAccountIDs
[]
int64
if
group
!=
nil
&&
requestedModel
!=
""
&&
group
.
Platform
==
PlatformAnthropic
{
if
group
!=
nil
&&
requestedModel
!=
""
&&
group
.
Platform
==
PlatformAnthropic
{
...
@@ -1598,7 +1599,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1598,7 +1599,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
account
,
ok
:=
accountByID
[
accountID
]
account
,
ok
:=
accountByID
[
accountID
]
if
ok
{
if
ok
{
// 检查账户是否需要清理粘性会话绑定
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky
:=
shouldClearStickySession
(
account
,
requestedModel
)
clearSticky
:=
shouldClearStickySession
(
account
,
requestedModel
)
if
clearSticky
{
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
...
@@ -1614,7 +1614,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1614,7 +1614,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
// 会话数量限制检查
// Session count limit check
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
}
else
{
...
@@ -1628,10 +1627,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1628,10 +1627,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
// 会话数量限制检查(等待计划也需要占用会话配额)
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
// 会话限制已满,继续到 Layer 2
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
}
else
{
}
else
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
AccountID
:
accountID
,
AccountID
:
accountID
,
...
@@ -2740,7 +2737,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
...
@@ -2740,7 +2737,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
clearSticky
{
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
s
.
isAccountSchedulableForModelSelection
(
ctx
,
account
,
requestedModel
)
&&
s
.
isAccountSchedulableForQuota
(
account
)
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
true
)
&&
s
.
isAccountSchedulableForRPM
(
ctx
,
account
,
true
)
{
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
s
.
isAccountSchedulableForModelSelection
(
ctx
,
account
,
requestedModel
)
&&
s
.
isAccountSchedulableForQuota
(
account
)
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
true
)
&&
s
.
isAccountSchedulableForRPM
(
ctx
,
account
,
true
)
&&
!
s
.
isStickyAccountUpstreamRestricted
(
ctx
,
groupID
,
account
,
requestedModel
)
{
if
s
.
debugModelRoutingEnabled
()
{
if
s
.
debugModelRoutingEnabled
()
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
}
}
...
@@ -3119,7 +3116,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
...
@@ -3119,7 +3116,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
clearSticky
{
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
s
.
isAccountSchedulableForModelSelection
(
ctx
,
account
,
requestedModel
)
&&
s
.
isAccountSchedulableForQuota
(
account
)
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
true
)
&&
s
.
isAccountSchedulableForRPM
(
ctx
,
account
,
true
)
{
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
s
.
isAccountSchedulableForModelSelection
(
ctx
,
account
,
requestedModel
)
&&
s
.
isAccountSchedulableForQuota
(
account
)
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
true
)
&&
s
.
isAccountSchedulableForRPM
(
ctx
,
account
,
true
)
&&
!
s
.
isStickyAccountUpstreamRestricted
(
ctx
,
groupID
,
account
,
requestedModel
)
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
return
account
,
nil
return
account
,
nil
}
}
...
@@ -3435,6 +3432,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
...
@@ -3435,6 +3432,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
_
,
ok
:=
ResolveBedrockModelID
(
account
,
requestedModel
)
_
,
ok
:=
ResolveBedrockModelID
(
account
,
requestedModel
)
return
ok
return
ok
}
}
// OpenAI 透传模式:仅替换认证,允许所有模型
if
account
.
Platform
==
PlatformOpenAI
&&
account
.
IsOpenAIPassthroughEnabled
()
{
return
true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
requestedModel
=
claude
.
NormalizeModelID
(
requestedModel
)
requestedModel
=
claude
.
NormalizeModelID
(
requestedModel
)
...
@@ -3934,6 +3935,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -3934,6 +3935,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return
nil
,
fmt
.
Errorf
(
"parse request: empty request"
)
return
nil
,
fmt
.
Errorf
(
"parse request: empty request"
)
}
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if
account
!=
nil
&&
s
.
shouldEmulateWebSearch
(
ctx
,
account
,
parsed
.
GroupID
,
parsed
.
Body
)
{
return
s
.
handleWebSearchEmulation
(
ctx
,
c
,
account
,
parsed
)
}
if
account
!=
nil
&&
account
.
IsAnthropicAPIKeyPassthroughEnabled
()
{
if
account
!=
nil
&&
account
.
IsAnthropicAPIKeyPassthroughEnabled
()
{
passthroughBody
:=
parsed
.
Body
passthroughBody
:=
parsed
.
Body
passthroughModel
:=
parsed
.
Model
passthroughModel
:=
parsed
.
Model
...
@@ -7279,6 +7285,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
...
@@ -7279,6 +7285,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
type
RecordUsageInput
struct
{
Result
*
ForwardResult
Result
*
ForwardResult
ParsedRequest
*
ParsedRequest
APIKey
*
APIKey
APIKey
*
APIKey
User
*
User
User
*
User
Account
*
Account
Account
*
Account
...
@@ -7333,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
...
@@ -7333,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return
p
.
Cost
.
TotalCost
>
0
&&
p
.
Account
.
IsAPIKeyOrBedrock
()
&&
p
.
Account
.
HasAnyQuotaLimit
()
return
p
.
Cost
.
TotalCost
>
0
&&
p
.
Account
.
IsAPIKeyOrBedrock
()
&&
p
.
Account
.
HasAnyQuotaLimit
()
}
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// postUsageBilling is the legacy fallback billing path used when the unified
// - 订阅/余额扣费
// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
// - API Key 配额更新
// for atomic billing. This path only runs in tests or degraded mode.
// - API Key 限速用量更新
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
func
postUsageBilling
(
ctx
context
.
Context
,
p
*
postUsageBillingParams
,
deps
*
billingDeps
)
{
func
postUsageBilling
(
ctx
context
.
Context
,
p
*
postUsageBillingParams
,
deps
*
billingDeps
)
{
billingCtx
,
cancel
:=
detachedBillingContext
(
ctx
)
billingCtx
,
cancel
:=
detachedBillingContext
(
ctx
)
defer
cancel
()
defer
cancel
()
cost
:=
p
.
Cost
cost
:=
p
.
Cost
// 1. 订阅 / 余额扣费
if
p
.
IsSubscriptionBill
{
if
p
.
IsSubscriptionBill
{
if
cost
.
TotalCost
>
0
{
if
cost
.
TotalCost
>
0
{
if
err
:=
deps
.
userSubRepo
.
IncrementUsage
(
billingCtx
,
p
.
Subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
if
err
:=
deps
.
userSubRepo
.
IncrementUsage
(
billingCtx
,
p
.
Subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
slog
.
Error
(
"increment subscription usage failed"
,
"subscription_id"
,
p
.
Subscription
.
ID
,
"error"
,
err
)
slog
.
Error
(
"increment subscription usage failed"
,
"subscription_id"
,
p
.
Subscription
.
ID
,
"error"
,
err
)
}
}
deps
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
p
.
User
.
ID
,
*
p
.
APIKey
.
GroupID
,
cost
.
TotalCost
)
}
}
}
else
{
}
else
{
if
cost
.
ActualCost
>
0
{
if
cost
.
ActualCost
>
0
{
if
err
:=
deps
.
userRepo
.
DeductBalance
(
billingCtx
,
p
.
User
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
if
err
:=
deps
.
userRepo
.
DeductBalance
(
billingCtx
,
p
.
User
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
slog
.
Error
(
"deduct balance failed"
,
"user_id"
,
p
.
User
.
ID
,
"error"
,
err
)
slog
.
Error
(
"deduct balance failed"
,
"user_id"
,
p
.
User
.
ID
,
"error"
,
err
)
}
}
deps
.
billingCacheService
.
QueueDeductBalance
(
p
.
User
.
ID
,
cost
.
ActualCost
)
}
}
}
}
// 2. API Key 配额
if
p
.
shouldDeductAPIKeyQuota
()
{
if
p
.
shouldDeductAPIKeyQuota
()
{
if
err
:=
p
.
APIKeyService
.
UpdateQuotaUsed
(
billingCtx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
if
err
:=
p
.
APIKeyService
.
UpdateQuotaUsed
(
billingCtx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
slog
.
Error
(
"update api key quota failed"
,
"api_key_id"
,
p
.
APIKey
.
ID
,
"error"
,
err
)
slog
.
Error
(
"update api key quota failed"
,
"api_key_id"
,
p
.
APIKey
.
ID
,
"error"
,
err
)
}
}
}
}
// 3. API Key 限速用量
if
p
.
shouldUpdateRateLimits
()
{
if
p
.
shouldUpdateRateLimits
()
{
if
err
:=
p
.
APIKeyService
.
UpdateRateLimitUsage
(
billingCtx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
if
err
:=
p
.
APIKeyService
.
UpdateRateLimitUsage
(
billingCtx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
slog
.
Error
(
"update api key rate limit usage failed"
,
"api_key_id"
,
p
.
APIKey
.
ID
,
"error"
,
err
)
slog
.
Error
(
"update api key rate limit usage failed"
,
"api_key_id"
,
p
.
APIKey
.
ID
,
"error"
,
err
)
}
}
}
}
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if
p
.
shouldUpdateAccountQuota
()
{
if
p
.
shouldUpdateAccountQuota
()
{
accountCost
:=
cost
.
TotalCost
*
p
.
AccountRateMultiplier
accountCost
:=
cost
.
TotalCost
*
p
.
AccountRateMultiplier
if
err
:=
deps
.
accountRepo
.
IncrementQuotaUsed
(
billingCtx
,
p
.
Account
.
ID
,
accountCost
);
err
!=
nil
{
if
err
:=
deps
.
accountRepo
.
IncrementQuotaUsed
(
billingCtx
,
p
.
Account
.
ID
,
accountCost
);
err
!=
nil
{
...
@@ -7383,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
...
@@ -7383,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
}
}
}
finalizePostUsageBilling
(
p
,
deps
)
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
// cache updates. The legacy path does DB writes directly; the finalize path
// does cache queue + notifications. Notifications are dispatched separately
// by the caller after recording the usage log.
}
}
func
resolveUsageBillingRequestID
(
ctx
context
.
Context
,
upstreamRequestID
string
)
string
{
func
resolveUsageBillingRequestID
(
ctx
context
.
Context
,
upstreamRequestID
string
)
string
{
...
@@ -7499,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
...
@@ -7499,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
}
}
}
}
finalizePostUsageBilling
(
p
,
deps
)
finalizePostUsageBilling
(
p
,
deps
,
result
)
return
true
,
nil
return
true
,
nil
}
}
func
finalizePostUsageBilling
(
p
*
postUsageBillingParams
,
deps
*
billingDeps
)
{
func
finalizePostUsageBilling
(
p
*
postUsageBillingParams
,
deps
*
billingDeps
,
result
*
UsageBillingApplyResult
)
{
if
p
==
nil
||
p
.
Cost
==
nil
||
deps
==
nil
{
if
p
==
nil
||
p
.
Cost
==
nil
||
deps
==
nil
{
return
return
}
}
...
@@ -7521,6 +7523,83 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
...
@@ -7521,6 +7523,83 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
}
}
deps
.
deferredService
.
ScheduleLastUsedUpdate
(
p
.
Account
.
ID
)
deps
.
deferredService
.
ScheduleLastUsedUpdate
(
p
.
Account
.
ID
)
// Notification checks run async — all parameters are already captured,
// no dependency on the request context or upstream connection.
go
notifyBalanceLow
(
p
,
deps
,
result
)
go
notifyAccountQuota
(
p
,
deps
,
result
)
}
// notifyBalanceLow sends balance low notification after deduction.
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
func
notifyBalanceLow
(
p
*
postUsageBillingParams
,
deps
*
billingDeps
,
result
*
UsageBillingApplyResult
)
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
slog
.
Error
(
"panic in notifyBalanceLow"
,
"recover"
,
r
)
}
}()
if
p
.
IsSubscriptionBill
||
p
.
Cost
.
ActualCost
<=
0
||
p
.
User
==
nil
||
deps
.
balanceNotifyService
==
nil
{
slog
.
Debug
(
"notifyBalanceLow: skipped"
,
"is_subscription"
,
p
.
IsSubscriptionBill
,
"actual_cost"
,
p
.
Cost
.
ActualCost
,
"user_nil"
,
p
.
User
==
nil
,
"service_nil"
,
deps
.
balanceNotifyService
==
nil
,
)
return
}
oldBalance
:=
resolveOldBalance
(
p
,
result
)
slog
.
Debug
(
"notifyBalanceLow: calling CheckBalanceAfterDeduction"
,
"user_id"
,
p
.
User
.
ID
,
"old_balance"
,
oldBalance
,
"cost"
,
p
.
Cost
.
ActualCost
,
"notify_enabled"
,
p
.
User
.
BalanceNotifyEnabled
,
"threshold"
,
p
.
User
.
BalanceNotifyThreshold
,
"result_has_new_balance"
,
result
!=
nil
&&
result
.
NewBalance
!=
nil
,
)
deps
.
balanceNotifyService
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
p
.
User
,
oldBalance
,
p
.
Cost
.
ActualCost
)
}
// resolveOldBalance returns the pre-deduction balance.
// Prefers the DB transaction result (newBalance + cost) over snapshot.
func
resolveOldBalance
(
p
*
postUsageBillingParams
,
result
*
UsageBillingApplyResult
)
float64
{
if
result
!=
nil
&&
result
.
NewBalance
!=
nil
{
return
*
result
.
NewBalance
+
p
.
Cost
.
ActualCost
}
// Legacy fallback: snapshot balance from request context
return
p
.
User
.
Balance
}
// notifyAccountQuota sends account quota threshold notification after increment.
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
// to avoid a separate DB read that may see stale or concurrently-modified data.
func
notifyAccountQuota
(
p
*
postUsageBillingParams
,
deps
*
billingDeps
,
result
*
UsageBillingApplyResult
)
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
slog
.
Error
(
"panic in notifyAccountQuota"
,
"recover"
,
r
)
}
}()
if
p
.
Cost
.
TotalCost
<=
0
||
p
.
Account
==
nil
||
!
p
.
Account
.
IsAPIKeyOrBedrock
()
||
deps
.
balanceNotifyService
==
nil
{
slog
.
Debug
(
"notifyAccountQuota: skipped"
,
"total_cost"
,
p
.
Cost
.
TotalCost
,
"account_nil"
,
p
.
Account
==
nil
,
"is_apikey_or_bedrock"
,
p
.
Account
!=
nil
&&
p
.
Account
.
IsAPIKeyOrBedrock
(),
"service_nil"
,
deps
.
balanceNotifyService
==
nil
,
)
return
}
accountCost
:=
p
.
Cost
.
TotalCost
*
p
.
AccountRateMultiplier
var
quotaState
*
AccountQuotaState
if
result
!=
nil
{
quotaState
=
result
.
QuotaState
}
slog
.
Debug
(
"notifyAccountQuota: calling CheckAccountQuotaAfterIncrement"
,
"account_id"
,
p
.
Account
.
ID
,
"account_cost"
,
accountCost
,
"has_quota_state"
,
quotaState
!=
nil
,
)
deps
.
balanceNotifyService
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
p
.
Account
,
accountCost
,
quotaState
)
}
}
func
detachedBillingContext
(
ctx
context
.
Context
)
(
context
.
Context
,
context
.
CancelFunc
)
{
func
detachedBillingContext
(
ctx
context
.
Context
)
(
context
.
Context
,
context
.
CancelFunc
)
{
...
@@ -7543,20 +7622,22 @@ func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Cont
...
@@ -7543,20 +7622,22 @@ func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Cont
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type
billingDeps
struct
{
type
billingDeps
struct
{
accountRepo
AccountRepository
accountRepo
AccountRepository
userRepo
UserRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
userSubRepo
UserSubscriptionRepository
billingCacheService
*
BillingCacheService
billingCacheService
*
BillingCacheService
deferredService
*
DeferredService
deferredService
*
DeferredService
balanceNotifyService
*
BalanceNotifyService
}
}
func
(
s
*
GatewayService
)
billingDeps
()
*
billingDeps
{
func
(
s
*
GatewayService
)
billingDeps
()
*
billingDeps
{
return
&
billingDeps
{
return
&
billingDeps
{
accountRepo
:
s
.
accountRepo
,
accountRepo
:
s
.
accountRepo
,
userRepo
:
s
.
userRepo
,
userRepo
:
s
.
userRepo
,
userSubRepo
:
s
.
userSubRepo
,
userSubRepo
:
s
.
userSubRepo
,
billingCacheService
:
s
.
billingCacheService
,
billingCacheService
:
s
.
billingCacheService
,
deferredService
:
s
.
deferredService
,
deferredService
:
s
.
deferredService
,
balanceNotifyService
:
s
.
balanceNotifyService
,
}
}
}
}
...
@@ -7746,6 +7827,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
...
@@ -7746,6 +7827,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
usageLog
:=
s
.
buildRecordUsageLog
(
ctx
,
input
,
result
,
apiKey
,
user
,
account
,
subscription
,
usageLog
:=
s
.
buildRecordUsageLog
(
ctx
,
input
,
result
,
apiKey
,
user
,
account
,
subscription
,
requestedModel
,
multiplier
,
accountRateMultiplier
,
billingType
,
cacheTTLOverridden
,
cost
,
opts
)
requestedModel
,
multiplier
,
accountRateMultiplier
,
billingType
,
cacheTTLOverridden
,
cost
,
opts
)
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if
apiKey
.
GroupID
!=
nil
{
applyAccountStatsCost
(
ctx
,
usageLog
,
s
.
channelService
,
s
.
billingService
,
account
.
ID
,
*
apiKey
.
GroupID
,
result
.
UpstreamModel
,
result
.
Model
,
// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
UsageTokens
{
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
ImageOutputTokens
:
result
.
Usage
.
ImageOutputTokens
,
},
cost
.
TotalCost
,
)
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.gateway"
)
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.gateway"
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
logger
.
LegacyPrintf
(
"service.gateway"
,
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
...
@@ -8086,6 +8184,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
...
@@ -8086,6 +8184,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return
ch
.
BillingModelSource
==
BillingModelSourceUpstream
return
ch
.
BillingModelSource
==
BillingModelSourceUpstream
}
}
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
func
(
s
*
GatewayService
)
isStickyAccountUpstreamRestricted
(
ctx
context
.
Context
,
groupID
*
int64
,
account
*
Account
,
requestedModel
string
)
bool
{
if
groupID
==
nil
{
return
false
}
if
!
s
.
needsUpstreamChannelRestrictionCheck
(
ctx
,
groupID
)
{
return
false
}
return
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
account
,
requestedModel
)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
// 特点:不记录使用量、仅支持非流式响应
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
error
{
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
error
{
...
...
backend/internal/service/gateway_websearch_emulation.go
0 → 100644
View file @
0b746501
package
service
import
(
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
// Web search emulation constants
const
(
toolTypeWebSearchPrefix
=
"web_search"
toolTypeGoogleSearch
=
"google_search"
toolNameWebSearch
=
"web_search"
toolNameGoogleSearch
=
"google_search"
toolNameWebSearch2025
=
"web_search_20250305"
webSearchDefaultMaxResults
=
5
defaultWebSearchModel
=
"claude-sonnet-4-6"
webSearchMsgIDPrefix
=
"msg_ws_"
webSearchToolUseIDPrefix
=
"srvtoolu_ws_"
tokenEstimateDivisor
=
4
// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
featureKeyWebSearchEmulation
=
"web_search_emulation"
)
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
var
webSearchManagerPtr
atomic
.
Pointer
[
websearch
.
Manager
]
// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
func
SetWebSearchManager
(
m
*
websearch
.
Manager
)
{
webSearchManagerPtr
.
Store
(
m
)
}
func
getWebSearchManager
()
*
websearch
.
Manager
{
return
webSearchManagerPtr
.
Load
()
}
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
func
(
s
*
GatewayService
)
shouldEmulateWebSearch
(
ctx
context
.
Context
,
account
*
Account
,
groupID
*
int64
,
body
[]
byte
)
bool
{
if
getWebSearchManager
()
==
nil
{
return
false
}
if
!
isOnlyWebSearchToolInBody
(
body
)
{
return
false
}
if
!
s
.
settingService
.
IsWebSearchEmulationEnabled
(
ctx
)
{
return
false
}
mode
:=
account
.
GetWebSearchEmulationMode
()
switch
mode
{
case
WebSearchModeEnabled
:
return
true
case
WebSearchModeDisabled
:
return
false
default
:
// "default" → follow channel config
if
groupID
==
nil
||
s
.
channelService
==
nil
{
return
false
}
ch
,
err
:=
s
.
channelService
.
GetChannelForGroup
(
ctx
,
*
groupID
)
if
err
!=
nil
||
ch
==
nil
{
return
false
}
return
ch
.
IsWebSearchEmulationEnabled
(
account
.
Platform
)
}
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
func
isOnlyWebSearchToolInBody
(
body
[]
byte
)
bool
{
tools
:=
gjson
.
GetBytes
(
body
,
"tools"
)
if
!
tools
.
IsArray
()
{
return
false
}
arr
:=
tools
.
Array
()
if
len
(
arr
)
!=
1
{
return
false
}
return
isWebSearchToolJSON
(
arr
[
0
])
}
func
isWebSearchToolJSON
(
tool
gjson
.
Result
)
bool
{
toolType
:=
tool
.
Get
(
"type"
)
.
String
()
if
strings
.
HasPrefix
(
toolType
,
toolTypeWebSearchPrefix
)
||
toolType
==
toolTypeGoogleSearch
{
return
true
}
switch
tool
.
Get
(
"name"
)
.
String
()
{
case
toolNameWebSearch
,
toolNameGoogleSearch
,
toolNameWebSearch2025
:
return
true
}
return
false
}
// extractSearchQueryFromBody extracts the last user message text as the search query.
func
extractSearchQueryFromBody
(
body
[]
byte
)
string
{
messages
:=
gjson
.
GetBytes
(
body
,
"messages"
)
if
!
messages
.
IsArray
()
{
return
""
}
arr
:=
messages
.
Array
()
if
len
(
arr
)
==
0
{
return
""
}
lastMsg
:=
arr
[
len
(
arr
)
-
1
]
if
lastMsg
.
Get
(
"role"
)
.
String
()
!=
"user"
{
return
""
}
return
extractWebSearchTextFromContent
(
lastMsg
.
Get
(
"content"
))
}
func
extractWebSearchTextFromContent
(
content
gjson
.
Result
)
string
{
if
content
.
Type
==
gjson
.
String
{
return
content
.
String
()
}
if
content
.
IsArray
()
{
for
_
,
block
:=
range
content
.
Array
()
{
if
block
.
Get
(
"type"
)
.
String
()
==
"text"
{
if
text
:=
block
.
Get
(
"text"
)
.
String
();
text
!=
""
{
return
text
}
}
}
}
return
""
}
// handleWebSearchEmulation intercepts a web-search-only request,
// calls a third-party search API, and constructs an Anthropic-format response.
func
(
s
*
GatewayService
)
handleWebSearchEmulation
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
,
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
// Release the serial queue lock immediately — we don't need upstream.
if
parsed
.
OnUpstreamAccepted
!=
nil
{
parsed
.
OnUpstreamAccepted
()
}
query
:=
extractSearchQueryFromBody
(
parsed
.
Body
)
if
query
==
""
{
return
nil
,
fmt
.
Errorf
(
"web search emulation: no query found in messages"
)
}
slog
.
Info
(
"web search emulation: executing search"
,
"account_id"
,
account
.
ID
,
"account_name"
,
account
.
Name
,
"query"
,
query
)
resp
,
providerName
,
err
:=
doWebSearch
(
ctx
,
account
,
query
)
if
err
!=
nil
{
// Proxy unavailable → trigger account switch via UpstreamFailoverError
if
errors
.
Is
(
err
,
websearch
.
ErrProxyUnavailable
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
http
.
StatusBadGateway
,
ResponseBody
:
[]
byte
(
err
.
Error
()),
}
}
return
nil
,
err
}
slog
.
Info
(
"web search emulation: search completed"
,
"provider"
,
providerName
,
"results_count"
,
len
(
resp
.
Results
))
model
:=
parsed
.
Model
if
model
==
""
{
model
=
defaultWebSearchModel
}
if
parsed
.
Stream
{
return
writeWebSearchStreamResponse
(
c
,
query
,
resp
,
model
,
startTime
)
}
return
writeWebSearchNonStreamResponse
(
c
,
query
,
resp
,
model
,
startTime
)
}
func
doWebSearch
(
ctx
context
.
Context
,
account
*
Account
,
query
string
)
(
*
websearch
.
SearchResponse
,
string
,
error
)
{
proxyURL
:=
resolveAccountProxyURL
(
account
)
mgr
:=
getWebSearchManager
()
if
mgr
==
nil
{
return
nil
,
""
,
fmt
.
Errorf
(
"web search emulation: manager not initialized"
)
}
resp
,
providerName
,
err
:=
mgr
.
SearchWithBestProvider
(
ctx
,
websearch
.
SearchRequest
{
Query
:
query
,
MaxResults
:
webSearchDefaultMaxResults
,
ProxyURL
:
proxyURL
,
})
if
err
!=
nil
{
slog
.
Error
(
"web search emulation: search failed"
,
"error"
,
err
)
return
nil
,
""
,
fmt
.
Errorf
(
"web search emulation: %w"
,
err
)
}
return
resp
,
providerName
,
nil
}
func
resolveAccountProxyURL
(
account
*
Account
)
string
{
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
return
account
.
Proxy
.
URL
()
}
return
""
}
// --- SSE streaming response ---
func
writeWebSearchStreamResponse
(
c
*
gin
.
Context
,
query
string
,
resp
*
websearch
.
SearchResponse
,
model
string
,
startTime
time
.
Time
,
)
(
*
ForwardResult
,
error
)
{
msgID
:=
webSearchMsgIDPrefix
+
uuid
.
New
()
.
String
()
toolUseID
:=
webSearchToolUseIDPrefix
+
uuid
.
New
()
.
String
()[
:
16
]
textSummary
:=
buildTextSummary
(
query
,
resp
.
Results
)
setSSEHeaders
(
c
)
w
:=
c
.
Writer
for
_
,
fn
:=
range
[]
func
()
error
{
func
()
error
{
return
writeSSEMessageStart
(
w
,
msgID
,
model
)
},
func
()
error
{
return
writeSSEServerToolUse
(
w
,
toolUseID
,
query
,
0
)
},
func
()
error
{
return
writeSSEToolResult
(
w
,
toolUseID
,
resp
.
Results
,
1
)
},
func
()
error
{
return
writeSSETextBlock
(
w
,
textSummary
,
2
)
},
func
()
error
{
return
writeSSEMessageEnd
(
w
,
len
(
textSummary
)
/
tokenEstimateDivisor
)
},
}
{
if
err
:=
fn
();
err
!=
nil
{
slog
.
Warn
(
"web search emulation: SSE write failed, stopping"
,
"error"
,
err
)
break
}
}
w
.
Flush
()
return
&
ForwardResult
{
Model
:
model
,
Duration
:
time
.
Since
(
startTime
),
Usage
:
ClaudeUsage
{}},
nil
}
func
setSSEHeaders
(
c
*
gin
.
Context
)
{
c
.
Writer
.
Header
()
.
Set
(
"Content-Type"
,
"text/event-stream"
)
c
.
Writer
.
Header
()
.
Set
(
"Cache-Control"
,
"no-cache"
)
c
.
Writer
.
Header
()
.
Set
(
"Connection"
,
"keep-alive"
)
c
.
Writer
.
Header
()
.
Set
(
"X-Accel-Buffering"
,
"no"
)
c
.
Writer
.
WriteHeader
(
http
.
StatusOK
)
}
func
writeSSEMessageStart
(
w
http
.
ResponseWriter
,
msgID
,
model
string
)
error
{
evt
:=
map
[
string
]
any
{
"type"
:
"message_start"
,
"message"
:
map
[
string
]
any
{
"id"
:
msgID
,
"type"
:
"message"
,
"role"
:
"assistant"
,
"model"
:
model
,
"content"
:
[]
any
{},
"stop_reason"
:
nil
,
"stop_sequence"
:
nil
,
"usage"
:
map
[
string
]
int
{
"input_tokens"
:
0
,
"output_tokens"
:
0
},
},
}
return
flushSSEJSON
(
w
,
"message_start"
,
evt
)
}
func
writeSSEServerToolUse
(
w
http
.
ResponseWriter
,
toolUseID
,
query
string
,
index
int
)
error
{
start
:=
map
[
string
]
any
{
"type"
:
"content_block_start"
,
"index"
:
index
,
"content_block"
:
map
[
string
]
any
{
"type"
:
"server_tool_use"
,
"id"
:
toolUseID
,
"name"
:
toolNameWebSearch
,
"input"
:
map
[
string
]
string
{
"query"
:
query
},
},
}
if
err
:=
flushSSEJSON
(
w
,
"content_block_start"
,
start
);
err
!=
nil
{
return
err
}
return
flushSSEJSON
(
w
,
"content_block_stop"
,
map
[
string
]
any
{
"type"
:
"content_block_stop"
,
"index"
:
index
})
}
func
writeSSEToolResult
(
w
http
.
ResponseWriter
,
toolUseID
string
,
results
[]
websearch
.
SearchResult
,
index
int
)
error
{
start
:=
map
[
string
]
any
{
"type"
:
"content_block_start"
,
"index"
:
index
,
"content_block"
:
map
[
string
]
any
{
"type"
:
"web_search_tool_result"
,
"tool_use_id"
:
toolUseID
,
"content"
:
buildSearchResultBlocks
(
results
),
},
}
if
err
:=
flushSSEJSON
(
w
,
"content_block_start"
,
start
);
err
!=
nil
{
return
err
}
return
flushSSEJSON
(
w
,
"content_block_stop"
,
map
[
string
]
any
{
"type"
:
"content_block_stop"
,
"index"
:
index
})
}
func
writeSSETextBlock
(
w
http
.
ResponseWriter
,
text
string
,
index
int
)
error
{
if
err
:=
flushSSEJSON
(
w
,
"content_block_start"
,
map
[
string
]
any
{
"type"
:
"content_block_start"
,
"index"
:
index
,
"content_block"
:
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
""
},
});
err
!=
nil
{
return
err
}
if
err
:=
flushSSEJSON
(
w
,
"content_block_delta"
,
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
index
,
"delta"
:
map
[
string
]
string
{
"type"
:
"text_delta"
,
"text"
:
text
},
});
err
!=
nil
{
return
err
}
return
flushSSEJSON
(
w
,
"content_block_stop"
,
map
[
string
]
any
{
"type"
:
"content_block_stop"
,
"index"
:
index
})
}
func
writeSSEMessageEnd
(
w
http
.
ResponseWriter
,
outputTokens
int
)
error
{
if
err
:=
flushSSEJSON
(
w
,
"message_delta"
,
map
[
string
]
any
{
"type"
:
"message_delta"
,
"delta"
:
map
[
string
]
any
{
"stop_reason"
:
"end_turn"
,
"stop_sequence"
:
nil
},
"usage"
:
map
[
string
]
int
{
"output_tokens"
:
outputTokens
},
});
err
!=
nil
{
return
err
}
return
flushSSEJSON
(
w
,
"message_stop"
,
map
[
string
]
string
{
"type"
:
"message_stop"
})
}
// flushSSEJSON marshals data to JSON and writes an SSE event.
func
flushSSEJSON
(
w
http
.
ResponseWriter
,
event
string
,
data
any
)
error
{
b
,
err
:=
json
.
Marshal
(
data
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal: %w"
,
err
)
}
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"event: %s
\n
data: %s
\n\n
"
,
event
,
b
);
err
!=
nil
{
return
fmt
.
Errorf
(
"write: %w"
,
err
)
}
if
f
,
ok
:=
w
.
(
http
.
Flusher
);
ok
{
f
.
Flush
()
}
return
nil
}
// --- Non-streaming JSON response ---
func
writeWebSearchNonStreamResponse
(
c
*
gin
.
Context
,
query
string
,
resp
*
websearch
.
SearchResponse
,
model
string
,
startTime
time
.
Time
,
)
(
*
ForwardResult
,
error
)
{
msgID
:=
webSearchMsgIDPrefix
+
uuid
.
New
()
.
String
()
toolUseID
:=
webSearchToolUseIDPrefix
+
uuid
.
New
()
.
String
()[
:
16
]
textSummary
:=
buildTextSummary
(
query
,
resp
.
Results
)
msg
:=
map
[
string
]
any
{
"id"
:
msgID
,
"type"
:
"message"
,
"role"
:
"assistant"
,
"model"
:
model
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"server_tool_use"
,
"id"
:
toolUseID
,
"name"
:
toolNameWebSearch
,
"input"
:
map
[
string
]
string
{
"query"
:
query
},
},
map
[
string
]
any
{
"type"
:
"web_search_tool_result"
,
"tool_use_id"
:
toolUseID
,
"content"
:
buildSearchResultBlocks
(
resp
.
Results
),
},
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
textSummary
},
},
"stop_reason"
:
"end_turn"
,
"stop_sequence"
:
nil
,
"usage"
:
map
[
string
]
int
{
"input_tokens"
:
0
,
"output_tokens"
:
len
(
textSummary
)
/
tokenEstimateDivisor
},
}
body
,
err
:=
json
.
Marshal
(
msg
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"web search emulation: marshal response: %w"
,
err
)
}
c
.
Data
(
http
.
StatusOK
,
"application/json"
,
body
)
return
&
ForwardResult
{
Model
:
model
,
Duration
:
time
.
Since
(
startTime
),
Usage
:
ClaudeUsage
{}},
nil
}
// --- Helpers ---
func
buildSearchResultBlocks
(
results
[]
websearch
.
SearchResult
)
[]
map
[
string
]
string
{
blocks
:=
make
([]
map
[
string
]
string
,
0
,
len
(
results
))
for
_
,
r
:=
range
results
{
block
:=
map
[
string
]
string
{
"type"
:
"web_search_result"
,
"url"
:
r
.
URL
,
"title"
:
r
.
Title
,
}
if
r
.
Snippet
!=
""
{
block
[
"page_content"
]
=
r
.
Snippet
}
if
r
.
PageAge
!=
""
{
block
[
"page_age"
]
=
r
.
PageAge
}
blocks
=
append
(
blocks
,
block
)
}
return
blocks
}
func
buildTextSummary
(
query
string
,
results
[]
websearch
.
SearchResult
)
string
{
if
len
(
results
)
==
0
{
return
"No search results found for: "
+
query
}
var
sb
strings
.
Builder
fmt
.
Fprintf
(
&
sb
,
"Here are the search results for
\"
%s
\"
:
\n\n
"
,
query
)
for
i
,
r
:=
range
results
{
fmt
.
Fprintf
(
&
sb
,
"%d. **%s**
\n
%s
\n
%s
\n\n
"
,
i
+
1
,
r
.
Title
,
r
.
URL
,
r
.
Snippet
)
}
return
sb
.
String
()
}
backend/internal/service/gateway_websearch_emulation_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/stretchr/testify/require"
)
// --- isOnlyWebSearchToolInBody ---
func
TestIsOnlyWebSearchToolInBody_WebSearchType
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[{"type":"web_search"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_WebSearch2025Type
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[{"type":"web_search_20250305"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_GoogleSearchType
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[{"type":"google_search"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_NameWebSearch
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[{"name":"web_search"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_NameWebSearch2025
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[{"name":"web_search_20250305"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_NameGoogleSearch
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[{"name":"google_search"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_MultipleTools
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isOnlyWebSearchToolInBody
(
[]
byte
(
`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_NoTools
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"model":"claude-3"}`
)))
}
func
TestIsOnlyWebSearchToolInBody_EmptyToolsArray
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_NonWebSearchTool
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":[{"type":"text_editor"}]}`
)))
}
func
TestIsOnlyWebSearchToolInBody_ToolsNotArray
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isOnlyWebSearchToolInBody
([]
byte
(
`{"tools":"web_search"}`
)))
}
// --- extractSearchQueryFromBody ---
func
TestExtractSearchQueryFromBody_StringContent
(
t
*
testing
.
T
)
{
body
:=
`{"messages":[{"role":"user","content":"what is golang"}]}`
require
.
Equal
(
t
,
"what is golang"
,
extractSearchQueryFromBody
([]
byte
(
body
)))
}
func
TestExtractSearchQueryFromBody_ArrayContent
(
t
*
testing
.
T
)
{
body
:=
`{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
require
.
Equal
(
t
,
"search this"
,
extractSearchQueryFromBody
([]
byte
(
body
)))
}
func
TestExtractSearchQueryFromBody_MultipleMessages
(
t
*
testing
.
T
)
{
body
:=
`{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
require
.
Equal
(
t
,
"second"
,
extractSearchQueryFromBody
([]
byte
(
body
)))
}
func
TestExtractSearchQueryFromBody_LastMessageNotUser
(
t
*
testing
.
T
)
{
body
:=
`{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
require
.
Equal
(
t
,
""
,
extractSearchQueryFromBody
([]
byte
(
body
)))
}
func
TestExtractSearchQueryFromBody_EmptyMessages
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
""
,
extractSearchQueryFromBody
([]
byte
(
`{"messages":[]}`
)))
}
func
TestExtractSearchQueryFromBody_NoMessages
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
""
,
extractSearchQueryFromBody
([]
byte
(
`{"model":"claude-3"}`
)))
}
func
TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText
(
t
*
testing
.
T
)
{
body
:=
`{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
require
.
Equal
(
t
,
"real query"
,
extractSearchQueryFromBody
([]
byte
(
body
)))
}
func
TestExtractSearchQueryFromBody_ArrayContentNoTextBlock
(
t
*
testing
.
T
)
{
body
:=
`{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
require
.
Equal
(
t
,
""
,
extractSearchQueryFromBody
([]
byte
(
body
)))
}
// --- buildSearchResultBlocks ---
func
TestBuildSearchResultBlocks_WithResults
(
t
*
testing
.
T
)
{
results
:=
[]
websearch
.
SearchResult
{
{
URL
:
"https://a.com"
,
Title
:
"A"
,
Snippet
:
"snippet a"
,
PageAge
:
"2 days"
},
{
URL
:
"https://b.com"
,
Title
:
"B"
,
Snippet
:
"snippet b"
},
}
blocks
:=
buildSearchResultBlocks
(
results
)
require
.
Len
(
t
,
blocks
,
2
)
require
.
Equal
(
t
,
"web_search_result"
,
blocks
[
0
][
"type"
])
require
.
Equal
(
t
,
"https://a.com"
,
blocks
[
0
][
"url"
])
require
.
Equal
(
t
,
"snippet a"
,
blocks
[
0
][
"page_content"
])
require
.
Equal
(
t
,
"2 days"
,
blocks
[
0
][
"page_age"
])
// Second result has no PageAge
require
.
Equal
(
t
,
"https://b.com"
,
blocks
[
1
][
"url"
])
_
,
hasPageAge
:=
blocks
[
1
][
"page_age"
]
require
.
False
(
t
,
hasPageAge
)
}
func
TestBuildSearchResultBlocks_Empty
(
t
*
testing
.
T
)
{
blocks
:=
buildSearchResultBlocks
(
nil
)
require
.
Empty
(
t
,
blocks
)
}
func
TestBuildSearchResultBlocks_SnippetEmpty
(
t
*
testing
.
T
)
{
blocks
:=
buildSearchResultBlocks
([]
websearch
.
SearchResult
{{
URL
:
"https://x.com"
,
Title
:
"X"
,
Snippet
:
""
}})
_
,
hasContent
:=
blocks
[
0
][
"page_content"
]
require
.
False
(
t
,
hasContent
)
}
// --- buildTextSummary ---
func
TestBuildTextSummary_WithResults
(
t
*
testing
.
T
)
{
results
:=
[]
websearch
.
SearchResult
{
{
URL
:
"https://a.com"
,
Title
:
"A"
,
Snippet
:
"desc a"
},
}
summary
:=
buildTextSummary
(
"test query"
,
results
)
require
.
Contains
(
t
,
summary
,
"test query"
)
require
.
Contains
(
t
,
summary
,
"1. **A**"
)
require
.
Contains
(
t
,
summary
,
"https://a.com"
)
}
func
TestBuildTextSummary_NoResults
(
t
*
testing
.
T
)
{
summary
:=
buildTextSummary
(
"test"
,
nil
)
require
.
Contains
(
t
,
summary
,
"No search results found for: test"
)
}
// --- shouldEmulateWebSearch ---
// webSearchToolBody is a valid request body with exactly one web_search tool.
var
webSearchToolBody
=
[]
byte
(
`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`
)
// nonWebSearchToolBody is a request body without web_search tool.
var
nonWebSearchToolBody
=
[]
byte
(
`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`
)
// newAnthropicAPIKeyAccount creates a test Account with the given web search emulation mode.
func
newAnthropicAPIKeyAccount
(
mode
string
)
*
Account
{
return
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
mode
},
}
}
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
func
setGlobalWebSearchConfig
(
cfg
*
WebSearchEmulationConfig
)
{
webSearchEmulationCache
.
Store
(
&
cachedWebSearchEmulationConfig
{
config
:
cfg
,
expiresAt
:
time
.
Now
()
.
Add
(
10
*
time
.
Minute
)
.
UnixNano
(),
})
}
// clearGlobalWebSearchConfig resets the global cache to force re-read.
func
clearGlobalWebSearchConfig
()
{
webSearchEmulationCache
.
Store
((
*
cachedWebSearchEmulationConfig
)(
nil
))
}
// newSettingServiceForWebSearchTest creates a SettingService with a mock repo pre-loaded with config.
func
newSettingServiceForWebSearchTest
(
enabled
bool
)
*
SettingService
{
repo
:=
newMockSettingRepo
()
cfg
:=
&
WebSearchEmulationConfig
{
Enabled
:
enabled
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"sk-test"
}},
}
data
,
_
:=
json
.
Marshal
(
cfg
)
repo
.
data
[
SettingKeyWebSearchEmulationConfig
]
=
string
(
data
)
return
NewSettingService
(
repo
,
&
config
.
Config
{})
}
// newChannelServiceWithCache creates a ChannelService with a pre-built cache containing the channel.
func
newChannelServiceWithCache
(
groupID
int64
,
ch
*
Channel
)
*
ChannelService
{
svc
:=
&
ChannelService
{}
cache
:=
&
channelCache
{
channelByGroupID
:
map
[
int64
]
*
Channel
{
groupID
:
ch
},
byID
:
map
[
int64
]
*
Channel
{
ch
.
ID
:
ch
},
groupPlatform
:
map
[
int64
]
string
{},
loadedAt
:
time
.
Now
(),
}
svc
.
cache
.
Store
(
cache
)
return
svc
}
func
TestShouldEmulateWebSearch_NilManager
(
t
*
testing
.
T
)
{
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
svc
:=
&
GatewayService
{
settingService
:
settingSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeEnabled
)
require
.
False
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
nil
,
webSearchToolBody
))
}
func
TestShouldEmulateWebSearch_NotOnlyWebSearchTool
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
svc
:=
&
GatewayService
{
settingService
:
settingSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeEnabled
)
require
.
False
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
nil
,
nonWebSearchToolBody
))
}
func
TestShouldEmulateWebSearch_GlobalDisabled
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
// Global config disabled
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
false
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
settingSvc
:=
newSettingServiceForWebSearchTest
(
false
)
svc
:=
&
GatewayService
{
settingService
:
settingSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeEnabled
)
require
.
False
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
nil
,
webSearchToolBody
))
}
func
TestShouldEmulateWebSearch_AccountDisabled
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
svc
:=
&
GatewayService
{
settingService
:
settingSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeDisabled
)
require
.
False
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
nil
,
webSearchToolBody
))
}
func
TestShouldEmulateWebSearch_AccountEnabled
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
svc
:=
&
GatewayService
{
settingService
:
settingSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeEnabled
)
require
.
True
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
nil
,
webSearchToolBody
))
}
func
TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
ch
:=
&
Channel
{
ID
:
10
,
Status
:
StatusActive
,
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
PlatformAnthropic
:
true
},
},
}
channelSvc
:=
newChannelServiceWithCache
(
42
,
ch
)
svc
:=
&
GatewayService
{
settingService
:
settingSvc
,
channelService
:
channelSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeDefault
)
groupID
:=
int64
(
42
)
require
.
True
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
&
groupID
,
webSearchToolBody
))
}
func
TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
ch
:=
&
Channel
{
ID
:
10
,
Status
:
StatusActive
,
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
PlatformAnthropic
:
false
},
},
}
channelSvc
:=
newChannelServiceWithCache
(
42
,
ch
)
svc
:=
&
GatewayService
{
settingService
:
settingSvc
,
channelService
:
channelSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeDefault
)
groupID
:=
int64
(
42
)
require
.
False
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
&
groupID
,
webSearchToolBody
))
}
func
TestShouldEmulateWebSearch_DefaultMode_NilGroupID
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
svc
:=
&
GatewayService
{
settingService
:
settingSvc
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeDefault
)
// nil groupID + default mode → falls through to channel check → returns false
require
.
False
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
nil
,
webSearchToolBody
))
}
func
TestShouldEmulateWebSearch_DefaultMode_NilChannelService
(
t
*
testing
.
T
)
{
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
setGlobalWebSearchConfig
(
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
})
defer
clearGlobalWebSearchConfig
()
settingSvc
:=
newSettingServiceForWebSearchTest
(
true
)
svc
:=
&
GatewayService
{
settingService
:
settingSvc
,
channelService
:
nil
}
account
:=
newAnthropicAPIKeyAccount
(
WebSearchModeDefault
)
groupID
:=
int64
(
42
)
// nil channelService + default mode → returns false
require
.
False
(
t
,
svc
.
shouldEmulateWebSearch
(
context
.
Background
(),
account
,
&
groupID
,
webSearchToolBody
))
}
backend/internal/service/notify_email_entry.go
0 → 100644
View file @
0b746501
package
service
import
(
"encoding/json"
"strings"
)
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// All emails are user-managed; maximum 3 entries per user.
type
NotifyEmailEntry
struct
{
Email
string
`json:"email"`
Disabled
bool
`json:"disabled"`
Verified
bool
`json:"verified"`
}
// parseNotifyEmails parses a JSON string into []NotifyEmailEntry.
// It auto-detects the format:
// - Old format ["email1","email2"] → converted to [{email, disabled:false, verified:true}, ...]
// - New format [{email,disabled,verified}, ...] → parsed directly
//
// Returns nil on empty/invalid input.
func
ParseNotifyEmails
(
raw
string
)
[]
NotifyEmailEntry
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
||
raw
==
"[]"
{
return
nil
}
// Try parsing as new format first (array of objects)
var
entries
[]
NotifyEmailEntry
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
entries
);
err
==
nil
&&
len
(
entries
)
>
0
{
// Verify it's actually the new format by checking the first element
// json.Unmarshal into []NotifyEmailEntry succeeds even for ["string"]
// because it tries to fit "string" into NotifyEmailEntry and gets zero values.
// We need to detect old format explicitly.
if
!
isOldStringArrayFormat
(
raw
)
{
return
entries
}
}
// Try parsing as old format (array of strings)
var
emails
[]
string
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
emails
);
err
==
nil
{
result
:=
make
([]
NotifyEmailEntry
,
0
,
len
(
emails
))
for
_
,
e
:=
range
emails
{
e
=
strings
.
TrimSpace
(
e
)
if
e
!=
""
{
result
=
append
(
result
,
NotifyEmailEntry
{
Email
:
e
,
Disabled
:
false
,
Verified
:
false
,
// Old format emails default to unverified
})
}
}
return
result
}
return
nil
}
// isOldStringArrayFormat checks if the JSON is a string array like ["email1","email2"].
func
isOldStringArrayFormat
(
raw
string
)
bool
{
var
arr
[]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
arr
);
err
!=
nil
||
len
(
arr
)
==
0
{
return
false
}
// Check if first element starts with a quote (string) vs { (object)
first
:=
strings
.
TrimSpace
(
string
(
arr
[
0
]))
return
len
(
first
)
>
0
&&
first
[
0
]
==
'"'
}
// marshalNotifyEmails serializes []NotifyEmailEntry to JSON string.
func
MarshalNotifyEmails
(
entries
[]
NotifyEmailEntry
)
string
{
if
len
(
entries
)
==
0
{
return
"[]"
}
data
,
err
:=
json
.
Marshal
(
entries
)
if
err
!=
nil
{
return
"[]"
}
return
string
(
data
)
}
backend/internal/service/notify_email_entry_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
// ---------- ParseNotifyEmails ----------
func
TestParseNotifyEmails_EmptyString
(
t
*
testing
.
T
)
{
result
:=
ParseNotifyEmails
(
""
)
require
.
Nil
(
t
,
result
)
}
func
TestParseNotifyEmails_EmptyArray
(
t
*
testing
.
T
)
{
result
:=
ParseNotifyEmails
(
"[]"
)
require
.
Nil
(
t
,
result
)
}
func
TestParseNotifyEmails_Null
(
t
*
testing
.
T
)
{
// "null" is valid JSON that unmarshals into a nil string slice.
// The old-format branch then returns an empty (non-nil) slice.
result
:=
ParseNotifyEmails
(
"null"
)
require
.
Empty
(
t
,
result
)
}
func
TestParseNotifyEmails_WhitespaceOnly
(
t
*
testing
.
T
)
{
result
:=
ParseNotifyEmails
(
" "
)
require
.
Nil
(
t
,
result
)
}
func
TestParseNotifyEmails_OldFormat
(
t
*
testing
.
T
)
{
raw
:=
`["alice@example.com", "bob@example.com"]`
result
:=
ParseNotifyEmails
(
raw
)
require
.
Len
(
t
,
result
,
2
)
require
.
Equal
(
t
,
"alice@example.com"
,
result
[
0
]
.
Email
)
require
.
False
(
t
,
result
[
0
]
.
Verified
,
"old format emails should default to unverified"
)
require
.
False
(
t
,
result
[
0
]
.
Disabled
)
require
.
Equal
(
t
,
"bob@example.com"
,
result
[
1
]
.
Email
)
require
.
False
(
t
,
result
[
1
]
.
Verified
)
require
.
False
(
t
,
result
[
1
]
.
Disabled
)
}
func
TestParseNotifyEmails_OldFormat_SkipsEmptyEntries
(
t
*
testing
.
T
)
{
raw
:=
`["alice@example.com", "", " ", "bob@example.com"]`
result
:=
ParseNotifyEmails
(
raw
)
require
.
Len
(
t
,
result
,
2
)
require
.
Equal
(
t
,
"alice@example.com"
,
result
[
0
]
.
Email
)
require
.
Equal
(
t
,
"bob@example.com"
,
result
[
1
]
.
Email
)
}
func
TestParseNotifyEmails_NewFormat
(
t
*
testing
.
T
)
{
raw
:=
`[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]`
result
:=
ParseNotifyEmails
(
raw
)
require
.
Len
(
t
,
result
,
2
)
require
.
Equal
(
t
,
"alice@example.com"
,
result
[
0
]
.
Email
)
require
.
True
(
t
,
result
[
0
]
.
Verified
)
require
.
False
(
t
,
result
[
0
]
.
Disabled
)
require
.
Equal
(
t
,
"bob@example.com"
,
result
[
1
]
.
Email
)
require
.
False
(
t
,
result
[
1
]
.
Verified
)
require
.
True
(
t
,
result
[
1
]
.
Disabled
)
}
func
TestParseNotifyEmails_NewFormat_SingleEntry
(
t
*
testing
.
T
)
{
raw
:=
`[{"email":"solo@example.com","verified":true,"disabled":false}]`
result
:=
ParseNotifyEmails
(
raw
)
require
.
Len
(
t
,
result
,
1
)
require
.
Equal
(
t
,
"solo@example.com"
,
result
[
0
]
.
Email
)
require
.
True
(
t
,
result
[
0
]
.
Verified
)
}
func
TestParseNotifyEmails_InvalidJSON
(
t
*
testing
.
T
)
{
result
:=
ParseNotifyEmails
(
`{not valid json`
)
require
.
Nil
(
t
,
result
)
}
func
TestParseNotifyEmails_InvalidJSONObject
(
t
*
testing
.
T
)
{
// A plain JSON object (not array) should return nil.
result
:=
ParseNotifyEmails
(
`{"email":"a@b.com"}`
)
require
.
Nil
(
t
,
result
)
}
func
TestParseNotifyEmails_WhitespacePadding
(
t
*
testing
.
T
)
{
raw
:=
` ["padded@example.com"] `
result
:=
ParseNotifyEmails
(
raw
)
require
.
Len
(
t
,
result
,
1
)
require
.
Equal
(
t
,
"padded@example.com"
,
result
[
0
]
.
Email
)
}
// ---------- MarshalNotifyEmails ----------
func
TestMarshalNotifyEmails_EmptySlice
(
t
*
testing
.
T
)
{
result
:=
MarshalNotifyEmails
([]
NotifyEmailEntry
{})
require
.
Equal
(
t
,
"[]"
,
result
)
}
func
TestMarshalNotifyEmails_NilSlice
(
t
*
testing
.
T
)
{
result
:=
MarshalNotifyEmails
(
nil
)
require
.
Equal
(
t
,
"[]"
,
result
)
}
func
TestMarshalNotifyEmails_SingleEntry
(
t
*
testing
.
T
)
{
entries
:=
[]
NotifyEmailEntry
{
{
Email
:
"test@example.com"
,
Verified
:
true
,
Disabled
:
false
},
}
result
:=
MarshalNotifyEmails
(
entries
)
require
.
Contains
(
t
,
result
,
`"email":"test@example.com"`
)
require
.
Contains
(
t
,
result
,
`"verified":true`
)
require
.
Contains
(
t
,
result
,
`"disabled":false`
)
// Round-trip: parsing the marshalled result should produce the original entries.
parsed
:=
ParseNotifyEmails
(
result
)
require
.
Len
(
t
,
parsed
,
1
)
require
.
Equal
(
t
,
entries
[
0
],
parsed
[
0
])
}
func
TestMarshalNotifyEmails_MultipleEntries
(
t
*
testing
.
T
)
{
entries
:=
[]
NotifyEmailEntry
{
{
Email
:
"a@example.com"
,
Verified
:
true
,
Disabled
:
false
},
{
Email
:
"b@example.com"
,
Verified
:
false
,
Disabled
:
true
},
}
result
:=
MarshalNotifyEmails
(
entries
)
// Round-trip verification.
parsed
:=
ParseNotifyEmails
(
result
)
require
.
Len
(
t
,
parsed
,
2
)
require
.
Equal
(
t
,
entries
[
0
],
parsed
[
0
])
require
.
Equal
(
t
,
entries
[
1
],
parsed
[
1
])
}
func
TestMarshalNotifyEmails_RoundTrip_NewFormat
(
t
*
testing
.
T
)
{
original
:=
[]
NotifyEmailEntry
{
{
Email
:
"x@example.com"
,
Verified
:
true
,
Disabled
:
true
},
{
Email
:
"y@example.com"
,
Verified
:
false
,
Disabled
:
false
},
}
marshalled
:=
MarshalNotifyEmails
(
original
)
parsed
:=
ParseNotifyEmails
(
marshalled
)
require
.
Equal
(
t
,
original
,
parsed
)
}
// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ----------
func
TestParseNotifyEmails_MixedOldFormatWithWhitespace
(
t
*
testing
.
T
)
{
// Emails with leading/trailing whitespace in old format should be trimmed.
raw
:=
`[" alice@example.com "]`
result
:=
ParseNotifyEmails
(
raw
)
require
.
Len
(
t
,
result
,
1
)
require
.
Equal
(
t
,
"alice@example.com"
,
result
[
0
]
.
Email
)
}
backend/internal/service/openai_gateway_record_usage_test.go
View file @
0b746501
...
@@ -147,6 +147,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
...
@@ -147,6 +147,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
)
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
rateRepo
,
rateRepo
,
...
...
backend/internal/service/openai_gateway_service.go
View file @
0b746501
...
@@ -327,6 +327,7 @@ type OpenAIGatewayService struct {
...
@@ -327,6 +327,7 @@ type OpenAIGatewayService struct {
openaiWSResolver
OpenAIWSProtocolResolver
openaiWSResolver
OpenAIWSProtocolResolver
resolver
*
ModelPricingResolver
resolver
*
ModelPricingResolver
channelService
*
ChannelService
channelService
*
ChannelService
balanceNotifyService
*
BalanceNotifyService
openaiWSPoolOnce
sync
.
Once
openaiWSPoolOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
...
@@ -364,6 +365,7 @@ func NewOpenAIGatewayService(
...
@@ -364,6 +365,7 @@ func NewOpenAIGatewayService(
openAITokenProvider
*
OpenAITokenProvider
,
openAITokenProvider
*
OpenAITokenProvider
,
resolver
*
ModelPricingResolver
,
resolver
*
ModelPricingResolver
,
channelService
*
ChannelService
,
channelService
*
ChannelService
,
balanceNotifyService
*
BalanceNotifyService
,
)
*
OpenAIGatewayService
{
)
*
OpenAIGatewayService
{
svc
:=
&
OpenAIGatewayService
{
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
...
@@ -393,6 +395,7 @@ func NewOpenAIGatewayService(
...
@@ -393,6 +395,7 @@ func NewOpenAIGatewayService(
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
resolver
:
resolver
,
resolver
:
resolver
,
channelService
:
channelService
,
channelService
:
channelService
,
balanceNotifyService
:
balanceNotifyService
,
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
codexSnapshotThrottle
:
newAccountWriteThrottle
(
openAICodexSnapshotPersistMinInterval
),
codexSnapshotThrottle
:
newAccountWriteThrottle
(
openAICodexSnapshotPersistMinInterval
),
}
}
...
@@ -477,11 +480,12 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
...
@@ -477,11 +480,12 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
func
(
s
*
OpenAIGatewayService
)
billingDeps
()
*
billingDeps
{
func
(
s
*
OpenAIGatewayService
)
billingDeps
()
*
billingDeps
{
return
&
billingDeps
{
return
&
billingDeps
{
accountRepo
:
s
.
accountRepo
,
accountRepo
:
s
.
accountRepo
,
userRepo
:
s
.
userRepo
,
userRepo
:
s
.
userRepo
,
userSubRepo
:
s
.
userSubRepo
,
userSubRepo
:
s
.
userSubRepo
,
billingCacheService
:
s
.
billingCacheService
,
billingCacheService
:
s
.
billingCacheService
,
deferredService
:
s
.
deferredService
,
deferredService
:
s
.
deferredService
,
balanceNotifyService
:
s
.
balanceNotifyService
,
}
}
}
}
...
@@ -1677,7 +1681,6 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
...
@@ -1677,7 +1681,6 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if
err
!=
nil
||
latest
==
nil
{
if
err
!=
nil
||
latest
==
nil
{
return
nil
return
nil
}
}
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
latest
,
time
.
Now
())
if
!
latest
.
IsSchedulable
()
||
!
latest
.
IsOpenAI
()
{
if
!
latest
.
IsSchedulable
()
||
!
latest
.
IsOpenAI
()
{
return
nil
return
nil
}
}
...
@@ -1700,7 +1703,6 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun
...
@@ -1700,7 +1703,6 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun
if
err
!=
nil
||
account
==
nil
{
if
err
!=
nil
||
account
==
nil
{
return
account
,
err
return
account
,
err
}
}
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
account
,
time
.
Now
())
return
account
,
nil
return
account
,
nil
}
}
...
@@ -4569,6 +4571,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -4569,6 +4571,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog
.
SubscriptionID
=
&
subscription
.
ID
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
}
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if
apiKey
.
GroupID
!=
nil
{
applyAccountStatsCost
(
ctx
,
usageLog
,
s
.
channelService
,
s
.
billingService
,
account
.
ID
,
*
apiKey
.
GroupID
,
result
.
UpstreamModel
,
result
.
Model
,
tokens
,
cost
.
TotalCost
,
)
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.openai_gateway"
)
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.openai_gateway"
)
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
...
@@ -4756,69 +4766,6 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow
...
@@ -4756,69 +4766,6 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow
return
updates
return
updates
}
}
func
codexUsagePercentExhausted
(
value
*
float64
)
bool
{
return
value
!=
nil
&&
*
value
>=
100
-
1e-9
}
func
codexRateLimitResetAtFromSnapshot
(
snapshot
*
OpenAICodexUsageSnapshot
,
fallbackNow
time
.
Time
)
*
time
.
Time
{
if
snapshot
==
nil
{
return
nil
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
return
nil
}
baseTime
:=
codexSnapshotBaseTime
(
snapshot
,
fallbackNow
)
if
codexUsagePercentExhausted
(
normalized
.
Used7dPercent
)
&&
normalized
.
Reset7dSeconds
!=
nil
{
resetAt
:=
baseTime
.
Add
(
time
.
Duration
(
*
normalized
.
Reset7dSeconds
)
*
time
.
Second
)
return
&
resetAt
}
if
codexUsagePercentExhausted
(
normalized
.
Used5hPercent
)
&&
normalized
.
Reset5hSeconds
!=
nil
{
resetAt
:=
baseTime
.
Add
(
time
.
Duration
(
*
normalized
.
Reset5hSeconds
)
*
time
.
Second
)
return
&
resetAt
}
return
nil
}
func
codexRateLimitResetAtFromExtra
(
extra
map
[
string
]
any
,
now
time
.
Time
)
*
time
.
Time
{
if
len
(
extra
)
==
0
{
return
nil
}
if
progress
:=
buildCodexUsageProgressFromExtra
(
extra
,
"7d"
,
now
);
progress
!=
nil
&&
codexUsagePercentExhausted
(
&
progress
.
Utilization
)
&&
progress
.
ResetsAt
!=
nil
&&
now
.
Before
(
*
progress
.
ResetsAt
)
{
resetAt
:=
progress
.
ResetsAt
.
UTC
()
return
&
resetAt
}
if
progress
:=
buildCodexUsageProgressFromExtra
(
extra
,
"5h"
,
now
);
progress
!=
nil
&&
codexUsagePercentExhausted
(
&
progress
.
Utilization
)
&&
progress
.
ResetsAt
!=
nil
&&
now
.
Before
(
*
progress
.
ResetsAt
)
{
resetAt
:=
progress
.
ResetsAt
.
UTC
()
return
&
resetAt
}
return
nil
}
func
applyOpenAICodexRateLimitFromExtra
(
account
*
Account
,
now
time
.
Time
)
(
*
time
.
Time
,
bool
)
{
if
account
==
nil
||
!
account
.
IsOpenAI
()
{
return
nil
,
false
}
resetAt
:=
codexRateLimitResetAtFromExtra
(
account
.
Extra
,
now
)
if
resetAt
==
nil
{
return
nil
,
false
}
if
account
.
RateLimitResetAt
!=
nil
&&
now
.
Before
(
*
account
.
RateLimitResetAt
)
&&
!
account
.
RateLimitResetAt
.
Before
(
*
resetAt
)
{
return
account
.
RateLimitResetAt
,
false
}
account
.
RateLimitResetAt
=
resetAt
return
resetAt
,
true
}
func
syncOpenAICodexRateLimitFromExtra
(
ctx
context
.
Context
,
repo
AccountRepository
,
account
*
Account
,
now
time
.
Time
)
*
time
.
Time
{
resetAt
,
changed
:=
applyOpenAICodexRateLimitFromExtra
(
account
,
now
)
if
!
changed
||
resetAt
==
nil
||
repo
==
nil
||
account
==
nil
||
account
.
ID
<=
0
{
return
resetAt
}
_
=
repo
.
SetRateLimited
(
ctx
,
account
.
ID
,
*
resetAt
)
return
resetAt
}
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
func
(
s
*
OpenAIGatewayService
)
updateCodexUsageSnapshot
(
ctx
context
.
Context
,
accountID
int64
,
snapshot
*
OpenAICodexUsageSnapshot
)
{
func
(
s
*
OpenAIGatewayService
)
updateCodexUsageSnapshot
(
ctx
context
.
Context
,
accountID
int64
,
snapshot
*
OpenAICodexUsageSnapshot
)
{
if
snapshot
==
nil
{
if
snapshot
==
nil
{
...
@@ -4830,24 +4777,17 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
...
@@ -4830,24 +4777,17 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
now
:=
time
.
Now
()
now
:=
time
.
Now
()
updates
:=
buildCodexUsageExtraUpdates
(
snapshot
,
now
)
updates
:=
buildCodexUsageExtraUpdates
(
snapshot
,
now
)
resetAt
:=
codexRateLimitResetAtFromSnapshot
(
snapshot
,
now
)
if
len
(
updates
)
==
0
{
if
len
(
updates
)
==
0
&&
resetAt
==
nil
{
return
return
}
}
shouldPersistUpdates
:=
len
(
updates
)
>
0
&&
s
.
getCodexSnapshotThrottle
()
.
Allow
(
accountID
,
now
)
if
!
s
.
getCodexSnapshotThrottle
()
.
Allow
(
accountID
,
now
)
{
if
!
shouldPersistUpdates
&&
resetAt
==
nil
{
return
return
}
}
go
func
()
{
go
func
()
{
updateCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
updateCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
defer
cancel
()
if
shouldPersistUpdates
{
_
=
s
.
accountRepo
.
UpdateExtra
(
updateCtx
,
accountID
,
updates
)
_
=
s
.
accountRepo
.
UpdateExtra
(
updateCtx
,
accountID
,
updates
)
}
if
resetAt
!=
nil
{
_
=
s
.
accountRepo
.
SetRateLimited
(
updateCtx
,
accountID
,
*
resetAt
)
}
}()
}()
}
}
...
...
backend/internal/service/openai_ws_forwarder_ingress_session_test.go
View file @
0b746501
...
@@ -413,7 +413,12 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
...
@@ -413,7 +413,12 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
select
{
select
{
case
serverErr
:=
<-
serverErrCh
:
case
serverErr
:=
<-
serverErrCh
:
require
.
NoError
(
t
,
serverErr
)
// After normal client close, the server goroutine may receive the close frame
// as an error — this is expected behavior, not a test failure.
if
serverErr
!=
nil
{
require
.
Contains
(
t
,
serverErr
.
Error
(),
"StatusNormalClosure"
,
"server error should only be a normal close frame, got: %v"
,
serverErr
)
}
case
<-
time
.
After
(
5
*
time
.
Second
)
:
case
<-
time
.
After
(
5
*
time
.
Second
)
:
t
.
Fatal
(
"等待 passthrough websocket 结束超时"
)
t
.
Fatal
(
"等待 passthrough websocket 结束超时"
)
}
}
...
...
backend/internal/service/openai_ws_protocol_forward_test.go
View file @
0b746501
...
@@ -617,6 +617,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
...
@@ -617,6 +617,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
)
decision
:=
svc
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
nil
)
decision
:=
svc
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
nil
)
...
...
backend/internal/service/openai_ws_ratelimit_signal_test.go
View file @
0b746501
...
@@ -345,7 +345,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
...
@@ -345,7 +345,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
}
}
}
}
func
TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSet
s
RateLimit
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshot
DoesNot
SetRateLimit
(
t
*
testing
.
T
)
{
repo
:=
&
openAICodexSnapshotAsyncRepo
{
repo
:=
&
openAICodexSnapshotAsyncRepo
{
updateExtraCh
:
make
(
chan
map
[
string
]
any
,
1
),
updateExtraCh
:
make
(
chan
map
[
string
]
any
,
1
),
rateLimitCh
:
make
(
chan
time
.
Time
,
1
),
rateLimitCh
:
make
(
chan
time
.
Time
,
1
),
...
@@ -359,7 +359,6 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRate
...
@@ -359,7 +359,6 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRate
SecondaryResetAfterSeconds
:
ptrIntWS
(
1200
),
SecondaryResetAfterSeconds
:
ptrIntWS
(
1200
),
SecondaryWindowMinutes
:
ptrIntWS
(
300
),
SecondaryWindowMinutes
:
ptrIntWS
(
300
),
}
}
before
:=
time
.
Now
()
svc
.
updateCodexUsageSnapshot
(
context
.
Background
(),
601
,
snapshot
)
svc
.
updateCodexUsageSnapshot
(
context
.
Background
(),
601
,
snapshot
)
select
{
select
{
...
@@ -371,9 +370,8 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRate
...
@@ -371,9 +370,8 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRate
select
{
select
{
case
resetAt
:=
<-
repo
.
rateLimitCh
:
case
resetAt
:=
<-
repo
.
rateLimitCh
:
require
.
WithinDuration
(
t
,
before
.
Add
(
time
.
Hour
),
resetAt
,
2
*
time
.
Second
)
t
.
Fatalf
(
"不应因仅写入快照而生成运行时限流时间: %v"
,
resetAt
)
case
<-
time
.
After
(
2
*
time
.
Second
)
:
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"等待 codex 100% 自动切换限流超时"
)
}
}
}
}
...
@@ -401,7 +399,7 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
...
@@ -401,7 +399,7 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
select
{
select
{
case
resetAt
:=
<-
repo
.
rateLimitCh
:
case
resetAt
:=
<-
repo
.
rateLimitCh
:
t
.
Fatalf
(
"
unexpected rate limit reset at
: %v"
,
resetAt
)
t
.
Fatalf
(
"
不应写入运行时限流时间
: %v"
,
resetAt
)
case
<-
time
.
After
(
200
*
time
.
Millisecond
)
:
case
<-
time
.
After
(
200
*
time
.
Millisecond
)
:
}
}
}
}
...
@@ -409,7 +407,6 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
...
@@ -409,7 +407,6 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
func
TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites
(
t
*
testing
.
T
)
{
repo
:=
&
openAICodexSnapshotAsyncRepo
{
repo
:=
&
openAICodexSnapshotAsyncRepo
{
updateExtraCh
:
make
(
chan
map
[
string
]
any
,
2
),
updateExtraCh
:
make
(
chan
map
[
string
]
any
,
2
),
rateLimitCh
:
make
(
chan
time
.
Time
,
2
),
}
}
svc
:=
&
OpenAIGatewayService
{
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
accountRepo
:
repo
,
...
@@ -443,7 +440,7 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *t
...
@@ -443,7 +440,7 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *t
func
ptrFloat64WS
(
v
float64
)
*
float64
{
return
&
v
}
func
ptrFloat64WS
(
v
float64
)
*
float64
{
return
&
v
}
func
ptrIntWS
(
v
int
)
*
int
{
return
&
v
}
func
ptrIntWS
(
v
int
)
*
int
{
return
&
v
}
func
TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSet
s
RateLimit
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtra
DoesNot
SetRateLimit
(
t
*
testing
.
T
)
{
resetAt
:=
time
.
Now
()
.
Add
(
6
*
24
*
time
.
Hour
)
resetAt
:=
time
.
Now
()
.
Add
(
6
*
24
*
time
.
Hour
)
account
:=
Account
{
account
:=
Account
{
ID
:
701
,
ID
:
701
,
...
@@ -463,17 +460,15 @@ func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateL
...
@@ -463,17 +460,15 @@ func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateL
fresh
,
err
:=
svc
.
getSchedulableAccount
(
context
.
Background
(),
account
.
ID
)
fresh
,
err
:=
svc
.
getSchedulableAccount
(
context
.
Background
(),
account
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
fresh
)
require
.
NotNil
(
t
,
fresh
)
require
.
NotNil
(
t
,
fresh
.
RateLimitResetAt
)
require
.
Nil
(
t
,
fresh
.
RateLimitResetAt
)
require
.
WithinDuration
(
t
,
resetAt
.
UTC
(),
*
fresh
.
RateLimitResetAt
,
time
.
Second
)
select
{
select
{
case
persisted
:=
<-
repo
.
rateLimitCh
:
case
persisted
:=
<-
repo
.
rateLimitCh
:
require
.
WithinDuration
(
t
,
resetAt
.
UTC
(),
persisted
,
time
.
Secon
d
)
t
.
Fatalf
(
"不应将已耗尽的 codex extra 提升为运行时限流状态: %v"
,
persiste
d
)
case
<-
time
.
After
(
2
*
time
.
Second
)
:
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"等待旧快照补写限流状态超时"
)
}
}
}
}
func
TestAdminService_ListAccounts_ExhaustedCodexExtra
ReturnsRateLimitedAccoun
t
(
t
*
testing
.
T
)
{
func
TestAdminService_ListAccounts_ExhaustedCodexExtra
DoesNotSetRateLimi
t
(
t
*
testing
.
T
)
{
resetAt
:=
time
.
Now
()
.
Add
(
4
*
24
*
time
.
Hour
)
resetAt
:=
time
.
Now
()
.
Add
(
4
*
24
*
time
.
Hour
)
repo
:=
&
openAICodexExtraListRepo
{
repo
:=
&
openAICodexExtraListRepo
{
stubOpenAIAccountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{{
stubOpenAIAccountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{{
...
@@ -496,13 +491,11 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
...
@@ -496,13 +491,11 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Len
(
t
,
accounts
,
1
)
require
.
Len
(
t
,
accounts
,
1
)
require
.
NotNil
(
t
,
accounts
[
0
]
.
RateLimitResetAt
)
require
.
Nil
(
t
,
accounts
[
0
]
.
RateLimitResetAt
)
require
.
WithinDuration
(
t
,
resetAt
.
UTC
(),
*
accounts
[
0
]
.
RateLimitResetAt
,
time
.
Second
)
select
{
select
{
case
persisted
:=
<-
repo
.
rateLimitCh
:
case
persisted
:=
<-
repo
.
rateLimitCh
:
require
.
WithinDuration
(
t
,
resetAt
.
UTC
(),
persisted
,
time
.
Secon
d
)
t
.
Fatalf
(
"不应在账号列表查询时将 codex extra 持久化为运行时限流状态: %v"
,
persiste
d
)
case
<-
time
.
After
(
2
*
time
.
Second
)
:
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"等待列表补写限流状态超时"
)
}
}
}
}
...
...
backend/internal/service/ops_concurrency.go
View file @
0b746501
...
@@ -64,12 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
...
@@ -64,12 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
if
acc
.
ID
<=
0
{
if
acc
.
ID
<=
0
{
continue
continue
}
}
c
:=
acc
.
Concurrency
lf
:=
acc
.
EffectiveLoadFactor
()
if
c
<=
0
{
if
prev
,
ok
:=
unique
[
acc
.
ID
];
!
ok
||
lf
>
prev
{
c
=
1
unique
[
acc
.
ID
]
=
lf
}
if
prev
,
ok
:=
unique
[
acc
.
ID
];
!
ok
||
c
>
prev
{
unique
[
acc
.
ID
]
=
c
}
}
}
}
...
...
backend/internal/service/ops_metrics_collector.go
View file @
0b746501
...
@@ -391,7 +391,7 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
...
@@ -391,7 +391,7 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
}
}
batch
=
append
(
batch
,
AccountWithConcurrency
{
batch
=
append
(
batch
,
AccountWithConcurrency
{
ID
:
acc
.
ID
,
ID
:
acc
.
ID
,
MaxConcurrency
:
acc
.
Concurrency
,
MaxConcurrency
:
acc
.
EffectiveLoadFactor
()
,
})
})
}
}
if
len
(
batch
)
==
0
{
if
len
(
batch
)
==
0
{
...
...
backend/internal/service/ops_system_log_sink_test.go
View file @
0b746501
...
@@ -183,6 +183,15 @@ func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) {
...
@@ -183,6 +183,15 @@ func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) {
if
strings
.
TrimSpace
(
item
.
Message
)
==
""
{
if
strings
.
TrimSpace
(
item
.
Message
)
==
""
{
t
.
Fatalf
(
"message should not be empty"
)
t
.
Fatalf
(
"message should not be empty"
)
}
}
// writtenCount is incremented after BatchInsertSystemLogsFn returns,
// so poll briefly to avoid a race between the done signal and the atomic add.
deadline
:=
time
.
Now
()
.
Add
(
time
.
Second
)
for
time
.
Now
()
.
Before
(
deadline
)
{
if
sink
.
Health
()
.
WrittenCount
>
0
{
break
}
time
.
Sleep
(
time
.
Millisecond
)
}
health
:=
sink
.
Health
()
health
:=
sink
.
Health
()
if
health
.
WrittenCount
==
0
{
if
health
.
WrittenCount
==
0
{
t
.
Fatalf
(
"written_count should be >0"
)
t
.
Fatalf
(
"written_count should be >0"
)
...
...
backend/internal/service/payment_amounts.go
0 → 100644
View file @
0b746501
package
service
import
(
"math"
"github.com/shopspring/decimal"
)
const
defaultBalanceRechargeMultiplier
=
1.0
func
normalizeBalanceRechargeMultiplier
(
multiplier
float64
)
float64
{
if
math
.
IsNaN
(
multiplier
)
||
math
.
IsInf
(
multiplier
,
0
)
||
multiplier
<=
0
{
return
defaultBalanceRechargeMultiplier
}
return
multiplier
}
func
calculateCreditedBalance
(
paymentAmount
,
multiplier
float64
)
float64
{
return
decimal
.
NewFromFloat
(
paymentAmount
)
.
Mul
(
decimal
.
NewFromFloat
(
normalizeBalanceRechargeMultiplier
(
multiplier
)))
.
Round
(
2
)
.
InexactFloat64
()
}
func
calculateGatewayRefundAmount
(
orderAmount
,
payAmount
,
refundAmount
float64
)
float64
{
if
orderAmount
<=
0
||
payAmount
<=
0
||
refundAmount
<=
0
{
return
0
}
if
math
.
Abs
(
refundAmount
-
orderAmount
)
<=
amountToleranceCNY
{
return
decimal
.
NewFromFloat
(
payAmount
)
.
Round
(
2
)
.
InexactFloat64
()
}
return
decimal
.
NewFromFloat
(
payAmount
)
.
Mul
(
decimal
.
NewFromFloat
(
refundAmount
))
.
Div
(
decimal
.
NewFromFloat
(
orderAmount
))
.
Round
(
2
)
.
InexactFloat64
()
}
backend/internal/service/payment_config_plans.go
View file @
0b746501
...
@@ -3,6 +3,7 @@ package service
...
@@ -3,6 +3,7 @@ package service
import
(
import
(
"context"
"context"
"fmt"
"fmt"
"strings"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/group"
...
@@ -10,6 +11,52 @@ import (
...
@@ -10,6 +11,52 @@ import (
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
)
// validatePlanRequired checks that all required fields for a plan are provided.
func
validatePlanRequired
(
name
string
,
groupID
int64
,
price
float64
,
validityDays
int
,
validityUnit
string
,
originalPrice
*
float64
)
error
{
if
strings
.
TrimSpace
(
name
)
==
""
{
return
infraerrors
.
BadRequest
(
"PLAN_NAME_REQUIRED"
,
"plan name is required"
)
}
if
groupID
<=
0
{
return
infraerrors
.
BadRequest
(
"PLAN_GROUP_REQUIRED"
,
"group is required"
)
}
if
price
<=
0
{
return
infraerrors
.
BadRequest
(
"PLAN_PRICE_INVALID"
,
"price must be > 0"
)
}
if
validityDays
<=
0
{
return
infraerrors
.
BadRequest
(
"PLAN_VALIDITY_REQUIRED"
,
"validity days must be > 0"
)
}
if
strings
.
TrimSpace
(
validityUnit
)
==
""
{
return
infraerrors
.
BadRequest
(
"PLAN_VALIDITY_UNIT_REQUIRED"
,
"validity unit is required"
)
}
if
originalPrice
!=
nil
&&
*
originalPrice
<
0
{
return
infraerrors
.
BadRequest
(
"PLAN_ORIGINAL_PRICE_INVALID"
,
"original price must be >= 0"
)
}
return
nil
}
// validatePlanPatch validates only the non-nil fields in a patch update.
func
validatePlanPatch
(
req
UpdatePlanRequest
)
error
{
if
req
.
Name
!=
nil
&&
strings
.
TrimSpace
(
*
req
.
Name
)
==
""
{
return
infraerrors
.
BadRequest
(
"PLAN_NAME_REQUIRED"
,
"plan name is required"
)
}
if
req
.
GroupID
!=
nil
&&
*
req
.
GroupID
<=
0
{
return
infraerrors
.
BadRequest
(
"PLAN_GROUP_REQUIRED"
,
"group is required"
)
}
if
req
.
Price
!=
nil
&&
*
req
.
Price
<=
0
{
return
infraerrors
.
BadRequest
(
"PLAN_PRICE_INVALID"
,
"price must be > 0"
)
}
if
req
.
ValidityDays
!=
nil
&&
*
req
.
ValidityDays
<=
0
{
return
infraerrors
.
BadRequest
(
"PLAN_VALIDITY_REQUIRED"
,
"validity days must be > 0"
)
}
if
req
.
ValidityUnit
!=
nil
&&
strings
.
TrimSpace
(
*
req
.
ValidityUnit
)
==
""
{
return
infraerrors
.
BadRequest
(
"PLAN_VALIDITY_UNIT_REQUIRED"
,
"validity unit is required"
)
}
if
req
.
OriginalPrice
!=
nil
&&
*
req
.
OriginalPrice
<
0
{
return
infraerrors
.
BadRequest
(
"PLAN_ORIGINAL_PRICE_INVALID"
,
"original price must be >= 0"
)
}
return
nil
}
// --- Plan CRUD ---
// --- Plan CRUD ---
// PlanGroupInfo holds the group details needed for subscription plan display.
// PlanGroupInfo holds the group details needed for subscription plan display.
...
@@ -74,6 +121,9 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S
...
@@ -74,6 +121,9 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S
}
}
func
(
s
*
PaymentConfigService
)
CreatePlan
(
ctx
context
.
Context
,
req
CreatePlanRequest
)
(
*
dbent
.
SubscriptionPlan
,
error
)
{
func
(
s
*
PaymentConfigService
)
CreatePlan
(
ctx
context
.
Context
,
req
CreatePlanRequest
)
(
*
dbent
.
SubscriptionPlan
,
error
)
{
if
err
:=
validatePlanRequired
(
req
.
Name
,
req
.
GroupID
,
req
.
Price
,
req
.
ValidityDays
,
req
.
ValidityUnit
,
req
.
OriginalPrice
);
err
!=
nil
{
return
nil
,
err
}
b
:=
s
.
entClient
.
SubscriptionPlan
.
Create
()
.
b
:=
s
.
entClient
.
SubscriptionPlan
.
Create
()
.
SetGroupID
(
req
.
GroupID
)
.
SetName
(
req
.
Name
)
.
SetDescription
(
req
.
Description
)
.
SetGroupID
(
req
.
GroupID
)
.
SetName
(
req
.
Name
)
.
SetDescription
(
req
.
Description
)
.
SetPrice
(
req
.
Price
)
.
SetValidityDays
(
req
.
ValidityDays
)
.
SetValidityUnit
(
req
.
ValidityUnit
)
.
SetPrice
(
req
.
Price
)
.
SetValidityDays
(
req
.
ValidityDays
)
.
SetValidityUnit
(
req
.
ValidityUnit
)
.
...
@@ -86,8 +136,12 @@ func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanReq
...
@@ -86,8 +136,12 @@ func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanReq
}
}
// UpdatePlan updates a subscription plan by ID (patch semantics).
// UpdatePlan updates a subscription plan by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate.
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate
// plus a validation guard for non-nil fields.
func
(
s
*
PaymentConfigService
)
UpdatePlan
(
ctx
context
.
Context
,
id
int64
,
req
UpdatePlanRequest
)
(
*
dbent
.
SubscriptionPlan
,
error
)
{
func
(
s
*
PaymentConfigService
)
UpdatePlan
(
ctx
context
.
Context
,
id
int64
,
req
UpdatePlanRequest
)
(
*
dbent
.
SubscriptionPlan
,
error
)
{
if
err
:=
validatePlanPatch
(
req
);
err
!=
nil
{
return
nil
,
err
}
u
:=
s
.
entClient
.
SubscriptionPlan
.
UpdateOneID
(
id
)
u
:=
s
.
entClient
.
SubscriptionPlan
.
UpdateOneID
(
id
)
if
req
.
GroupID
!=
nil
{
if
req
.
GroupID
!=
nil
{
u
.
SetGroupID
(
*
req
.
GroupID
)
u
.
SetGroupID
(
*
req
.
GroupID
)
...
...
backend/internal/service/payment_config_plans_validation_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestValidatePlanRequired_AllValid
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
30
,
"days"
,
nil
)
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanRequired_EmptyName
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
""
,
1
,
9.99
,
30
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"plan name"
)
}
func
TestValidatePlanRequired_WhitespaceName
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
" "
,
1
,
9.99
,
30
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"plan name"
)
}
func
TestValidatePlanRequired_ZeroGroupID
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
0
,
9.99
,
30
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"group"
)
}
func
TestValidatePlanRequired_NegativeGroupID
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
-
1
,
9.99
,
30
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"group"
)
}
func
TestValidatePlanRequired_ZeroPrice
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
1
,
0
,
30
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"price"
)
}
func
TestValidatePlanRequired_NegativePrice
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
1
,
-
5
,
30
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"price"
)
}
func
TestValidatePlanRequired_ZeroValidityDays
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
0
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"validity days"
)
}
func
TestValidatePlanRequired_NegativeValidityDays
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
-
7
,
"days"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"validity days"
)
}
func
TestValidatePlanRequired_EmptyValidityUnit
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
30
,
""
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"validity unit"
)
}
func
TestValidatePlanRequired_WhitespaceValidityUnit
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
30
,
" "
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"validity unit"
)
}
func
TestValidatePlanRequired_NameValidatedFirst
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
""
,
0
,
0
,
0
,
""
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"plan name"
)
}
func
TestValidatePlanRequired_TrimmedValidName
(
t
*
testing
.
T
)
{
err
:=
validatePlanRequired
(
" Pro "
,
1
,
9.99
,
30
,
"days"
,
nil
)
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanRequired_NegativeOriginalPrice
(
t
*
testing
.
T
)
{
neg
:=
-
10.0
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
30
,
"days"
,
&
neg
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"original price"
)
}
func
TestValidatePlanRequired_ZeroOriginalPrice
(
t
*
testing
.
T
)
{
zero
:=
0.0
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
30
,
"days"
,
&
zero
)
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanRequired_ValidOriginalPrice
(
t
*
testing
.
T
)
{
op
:=
19.99
err
:=
validatePlanRequired
(
"Pro"
,
1
,
9.99
,
30
,
"days"
,
&
op
)
require
.
NoError
(
t
,
err
)
}
// --- validatePlanPatch tests ---
func
TestValidatePlanPatch_NegativeOriginalPrice
(
t
*
testing
.
T
)
{
neg
:=
-
5.0
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
OriginalPrice
:
&
neg
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"original price"
)
}
func
TestValidatePlanPatch_ZeroOriginalPrice
(
t
*
testing
.
T
)
{
zero
:=
0.0
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
OriginalPrice
:
&
zero
})
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanPatch_ValidOriginalPrice
(
t
*
testing
.
T
)
{
op
:=
29.99
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
OriginalPrice
:
&
op
})
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanPatch_NilOriginalPrice
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
OriginalPrice
:
nil
})
require
.
NoError
(
t
,
err
)
}
// --- validatePlanPatch: other fields ---
func
ptrStr
(
s
string
)
*
string
{
return
&
s
}
func
ptrInt
(
i
int
)
*
int
{
return
&
i
}
func
ptrInt64
(
i
int64
)
*
int64
{
return
&
i
}
func
ptrFloat
(
f
float64
)
*
float64
{
return
&
f
}
func
TestValidatePlanPatch_EmptyName
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
Name
:
ptrStr
(
""
)})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"plan name"
)
}
func
TestValidatePlanPatch_ValidName
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
Name
:
ptrStr
(
"Basic"
)})
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanPatch_ZeroGroupID
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
GroupID
:
ptrInt64
(
0
)})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"group"
)
}
func
TestValidatePlanPatch_NegativePrice
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
Price
:
ptrFloat
(
-
1
)})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"price"
)
}
func
TestValidatePlanPatch_ZeroPrice
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
Price
:
ptrFloat
(
0
)})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"price"
)
}
func
TestValidatePlanPatch_ValidPrice
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
Price
:
ptrFloat
(
9.99
)})
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanPatch_ZeroValidityDays
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
ValidityDays
:
ptrInt
(
0
)})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"validity days"
)
}
func
TestValidatePlanPatch_EmptyValidityUnit
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
ValidityUnit
:
ptrStr
(
""
)})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"validity unit"
)
}
func
TestValidatePlanPatch_ValidValidityUnit
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{
ValidityUnit
:
ptrStr
(
"days"
)})
require
.
NoError
(
t
,
err
)
}
func
TestValidatePlanPatch_AllNil
(
t
*
testing
.
T
)
{
err
:=
validatePlanPatch
(
UpdatePlanRequest
{})
require
.
NoError
(
t
,
err
)
}
backend/internal/service/payment_config_providers.go
View file @
0b746501
...
@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db
...
@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db
// ProviderInstanceResponse is the API response for a provider instance.
// ProviderInstanceResponse is the API response for a provider instance.
type
ProviderInstanceResponse
struct
{
type
ProviderInstanceResponse
struct
{
ID
int64
`json:"id"`
ID
int64
`json:"id"`
ProviderKey
string
`json:"provider_key"`
ProviderKey
string
`json:"provider_key"`
Name
string
`json:"name"`
Name
string
`json:"name"`
Config
map
[
string
]
string
`json:"config"`
Config
map
[
string
]
string
`json:"config"`
SupportedTypes
[]
string
`json:"supported_types"`
SupportedTypes
[]
string
`json:"supported_types"`
Limits
string
`json:"limits"`
Limits
string
`json:"limits"`
Enabled
bool
`json:"enabled"`
Enabled
bool
`json:"enabled"`
RefundEnabled
bool
`json:"refund_enabled"`
RefundEnabled
bool
`json:"refund_enabled"`
SortOrder
int
`json:"sort_order"`
AllowUserRefund
bool
`json:"allow_user_refund"`
PaymentMode
string
`json:"payment_mode"`
SortOrder
int
`json:"sort_order"`
PaymentMode
string
`json:"payment_mode"`
}
}
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
...
@@ -46,8 +47,9 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
...
@@ -46,8 +47,9 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp
:=
ProviderInstanceResponse
{
resp
:=
ProviderInstanceResponse
{
ID
:
int64
(
inst
.
ID
),
ProviderKey
:
inst
.
ProviderKey
,
Name
:
inst
.
Name
,
ID
:
int64
(
inst
.
ID
),
ProviderKey
:
inst
.
ProviderKey
,
Name
:
inst
.
Name
,
SupportedTypes
:
splitTypes
(
inst
.
SupportedTypes
),
Limits
:
inst
.
Limits
,
SupportedTypes
:
splitTypes
(
inst
.
SupportedTypes
),
Limits
:
inst
.
Limits
,
Enabled
:
inst
.
Enabled
,
RefundEnabled
:
inst
.
RefundEnabled
,
SortOrder
:
inst
.
SortOrder
,
Enabled
:
inst
.
Enabled
,
RefundEnabled
:
inst
.
RefundEnabled
,
PaymentMode
:
inst
.
PaymentMode
,
AllowUserRefund
:
inst
.
AllowUserRefund
,
SortOrder
:
inst
.
SortOrder
,
PaymentMode
:
inst
.
PaymentMode
,
}
}
resp
.
Config
,
err
=
s
.
decryptAndMaskConfig
(
inst
.
Config
)
resp
.
Config
,
err
=
s
.
decryptAndMaskConfig
(
inst
.
Config
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
...
@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
allowUserRefund
:=
req
.
AllowUserRefund
&&
req
.
RefundEnabled
return
s
.
entClient
.
PaymentProviderInstance
.
Create
()
.
return
s
.
entClient
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
req
.
ProviderKey
)
.
SetName
(
req
.
Name
)
.
SetConfig
(
enc
)
.
SetProviderKey
(
req
.
ProviderKey
)
.
SetName
(
req
.
Name
)
.
SetConfig
(
enc
)
.
SetSupportedTypes
(
typesStr
)
.
SetEnabled
(
req
.
Enabled
)
.
SetPaymentMode
(
req
.
PaymentMode
)
.
SetSupportedTypes
(
typesStr
)
.
SetEnabled
(
req
.
Enabled
)
.
SetPaymentMode
(
req
.
PaymentMode
)
.
SetSortOrder
(
req
.
SortOrder
)
.
SetLimits
(
req
.
Limits
)
.
SetRefundEnabled
(
req
.
RefundEnabled
)
.
SetSortOrder
(
req
.
SortOrder
)
.
SetLimits
(
req
.
Limits
)
.
SetRefundEnabled
(
req
.
RefundEnabled
)
.
SetAllowUserRefund
(
allowUserRefund
)
.
Save
(
ctx
)
Save
(
ctx
)
}
}
...
@@ -221,6 +225,29 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
...
@@ -221,6 +225,29 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
if
req
.
RefundEnabled
!=
nil
{
if
req
.
RefundEnabled
!=
nil
{
u
.
SetRefundEnabled
(
*
req
.
RefundEnabled
)
u
.
SetRefundEnabled
(
*
req
.
RefundEnabled
)
// Cascade: turning off refund_enabled also disables allow_user_refund
if
!*
req
.
RefundEnabled
{
u
.
SetAllowUserRefund
(
false
)
}
}
if
req
.
AllowUserRefund
!=
nil
{
// Only allow enabling when refund_enabled is (or will be) true
if
*
req
.
AllowUserRefund
{
refundEnabled
:=
false
if
req
.
RefundEnabled
!=
nil
{
refundEnabled
=
*
req
.
RefundEnabled
}
else
{
inst
,
err
:=
s
.
entClient
.
PaymentProviderInstance
.
Get
(
ctx
,
id
)
if
err
==
nil
{
refundEnabled
=
inst
.
RefundEnabled
}
}
if
refundEnabled
{
u
.
SetAllowUserRefund
(
true
)
}
}
else
{
u
.
SetAllowUserRefund
(
false
)
}
}
}
if
req
.
PaymentMode
!=
nil
{
if
req
.
PaymentMode
!=
nil
{
u
.
SetPaymentMode
(
*
req
.
PaymentMode
)
u
.
SetPaymentMode
(
*
req
.
PaymentMode
)
...
@@ -228,6 +255,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
...
@@ -228,6 +255,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
return
u
.
Save
(
ctx
)
return
u
.
Save
(
ctx
)
}
}
// GetUserRefundEligibleInstanceIDs returns provider instance IDs that allow user refund.
func
(
s
*
PaymentConfigService
)
GetUserRefundEligibleInstanceIDs
(
ctx
context
.
Context
)
([]
string
,
error
)
{
instances
,
err
:=
s
.
entClient
.
PaymentProviderInstance
.
Query
()
.
Where
(
paymentproviderinstance
.
RefundEnabledEQ
(
true
),
paymentproviderinstance
.
AllowUserRefundEQ
(
true
),
)
.
Select
(
paymentproviderinstance
.
FieldID
)
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
ids
:=
make
([]
string
,
0
,
len
(
instances
))
for
_
,
inst
:=
range
instances
{
ids
=
append
(
ids
,
strconv
.
FormatInt
(
int64
(
inst
.
ID
),
10
))
}
return
ids
,
nil
}
func
(
s
*
PaymentConfigService
)
mergeConfig
(
ctx
context
.
Context
,
id
int64
,
newConfig
map
[
string
]
string
)
(
map
[
string
]
string
,
error
)
{
func
(
s
*
PaymentConfigService
)
mergeConfig
(
ctx
context
.
Context
,
id
int64
,
newConfig
map
[
string
]
string
)
(
map
[
string
]
string
,
error
)
{
inst
,
err
:=
s
.
entClient
.
PaymentProviderInstance
.
Get
(
ctx
,
id
)
inst
,
err
:=
s
.
entClient
.
PaymentProviderInstance
.
Get
(
ctx
,
id
)
if
err
!=
nil
{
if
err
!=
nil
{
...
...
backend/internal/service/payment_config_providers_test.go
View file @
0b746501
...
@@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) {
...
@@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) {
t
.
Parallel
()
t
.
Parallel
()
tests
:=
[]
struct
{
tests
:=
[]
struct
{
field
string
field
string
wantSen
bool
wantSen
bool
}{
}{
// Sensitive fields (contain key/secret/private/password/pkey patterns)
// Sensitive fields (contain key/secret/private/password/pkey patterns)
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment