Commit 0b746501 authored by 陈曦's avatar 陈曦
Browse files

1. Merge upstream v0.1.113; 2. 提交 migration 相关文件

parents 45061102 be7551b9
......@@ -43,6 +43,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
)
}
......
......@@ -75,6 +75,9 @@ type ParsedRequest struct {
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
// GroupID 请求所属分组 ID(来自 API Key)
GroupID *int64
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
OnUpstreamAccepted func()
......
......@@ -503,7 +503,6 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
......@@ -570,6 +569,7 @@ type GatewayService struct {
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
balanceNotifyService *BalanceNotifyService
}
// NewGatewayService creates a new GatewayService
......@@ -599,6 +599,7 @@ func NewGatewayService(
tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
balanceNotifyService *BalanceNotifyService,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
......@@ -633,6 +634,7 @@ func NewGatewayService(
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
balanceNotifyService: balanceNotifyService,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
......@@ -1329,6 +1331,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
......@@ -1337,12 +1344,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return excluded
}
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
......@@ -1598,7 +1599,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
account, ok := accountByID[accountID]
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
......@@ -1614,7 +1614,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
......@@ -1628,10 +1627,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: accountID,
......@@ -2740,7 +2737,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
......@@ -3119,7 +3116,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
......@@ -3435,6 +3432,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
}
// OpenAI 透传模式:仅替换认证,允许所有模型
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
return true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
......@@ -3934,6 +3935,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request")
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body
passthroughModel := parsed.Model
......@@ -7279,6 +7285,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ParsedRequest *ParsedRequest
APIKey *APIKey
User *User
Account *Account
......@@ -7333,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
// - API Key 限速用量更新
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
// postUsageBilling is the legacy fallback billing path used when the unified
// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
// for atomic billing. This path only runs in tests or degraded mode.
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
billingCtx, cancel := detachedBillingContext(ctx)
defer cancel()
cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill {
if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
}
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
}
}
// 2. API Key 配额
if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
......@@ -7383,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
}
finalizePostUsageBilling(p, deps)
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
// cache updates. The legacy path does DB writes directly; the finalize path
// does cache queue + notifications. Notifications are dispatched separately
// by the caller after recording the usage log.
}
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
......@@ -7499,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
}
}
finalizePostUsageBilling(p, deps)
finalizePostUsageBilling(p, deps, result)
return true, nil
}
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
if p == nil || p.Cost == nil || deps == nil {
return
}
......@@ -7521,6 +7523,83 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
}
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
// Notification checks run async — all parameters are already captured,
// no dependency on the request context or upstream connection.
go notifyBalanceLow(p, deps, result)
go notifyAccountQuota(p, deps, result)
}
// notifyBalanceLow sends balance low notification after deduction.
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
	// Runs on its own goroutine (spawned by finalizePostUsageBilling); a panic
	// here must not take down the process, so recover and log instead.
	defer func() {
		if r := recover(); r != nil {
			slog.Error("panic in notifyBalanceLow", "recover", r)
		}
	}()
	// Skip when: billed against a subscription (balance untouched), nothing was
	// actually deducted, or the user / notification service is unavailable.
	if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil {
		slog.Debug("notifyBalanceLow: skipped",
			"is_subscription", p.IsSubscriptionBill,
			"actual_cost", p.Cost.ActualCost,
			"user_nil", p.User == nil,
			"service_nil", deps.balanceNotifyService == nil,
		)
		return
	}
	oldBalance := resolveOldBalance(p, result)
	slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction",
		"user_id", p.User.ID,
		"old_balance", oldBalance,
		"cost", p.Cost.ActualCost,
		"notify_enabled", p.User.BalanceNotifyEnabled,
		"threshold", p.User.BalanceNotifyThreshold,
		"result_has_new_balance", result != nil && result.NewBalance != nil,
	)
	// context.Background(): the originating request context may already be
	// cancelled by the time this async check runs.
	deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
}
// resolveOldBalance reconstructs the balance as it stood just before this
// deduction. The authoritative source is the billing transaction result
// (post-deduction balance plus the deducted amount); when that is absent the
// request-time snapshot on the user record is used as a best-effort fallback.
func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 {
	if result == nil || result.NewBalance == nil {
		// Legacy fallback: snapshot balance captured with the request context.
		return p.User.Balance
	}
	return *result.NewBalance + p.Cost.ActualCost
}
// notifyAccountQuota sends account quota threshold notification after increment.
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
// to avoid a separate DB read that may see stale or concurrently-modified data.
func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
	// Runs on its own goroutine (spawned by finalizePostUsageBilling); never
	// let a panic escape the notification path.
	defer func() {
		if r := recover(); r != nil {
			slog.Error("panic in notifyAccountQuota", "recover", r)
		}
	}()
	// Skip when nothing was billed, the account is missing or of a kind
	// without quota tracking (only API-key/Bedrock accounts have it), or no
	// notifier is wired.
	if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil {
		slog.Debug("notifyAccountQuota: skipped",
			"total_cost", p.Cost.TotalCost,
			"account_nil", p.Account == nil,
			"is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(),
			"service_nil", deps.balanceNotifyService == nil,
		)
		return
	}
	// Account-level cost = total cost scaled by the account billing multiplier.
	accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
	var quotaState *AccountQuotaState
	if result != nil {
		quotaState = result.QuotaState
	}
	slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement",
		"account_id", p.Account.ID,
		"account_cost", accountCost,
		"has_quota_state", quotaState != nil,
	)
	// context.Background(): the originating request context may already be gone.
	deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState)
}
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
......@@ -7543,20 +7622,22 @@ func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Cont
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct {
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
balanceNotifyService *BalanceNotifyService
}
func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
}
}
......@@ -7746,6 +7827,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
},
cost.TotalCost,
)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
......@@ -8086,6 +8184,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return ch.BillingModelSource == BillingModelSourceUpstream
}
// isStickyAccountUpstreamRestricted reports whether an account matched via a
// sticky session is blocked by the group's upstream channel restrictions.
// It folds needsUpstreamChannelRestrictionCheck and
// isUpstreamModelRestrictedByChannel into a single predicate so the sticky
// session condition chains stay readable.
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
	if groupID == nil || !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
		return false
	}
	return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
......
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
// Web search emulation constants.
const (
	toolTypeWebSearchPrefix = "web_search"    // matches "web_search" and dated variants such as "web_search_20250305"
	toolTypeGoogleSearch    = "google_search" // Gemini-style search tool type
	toolNameWebSearch       = "web_search"
	toolNameGoogleSearch    = "google_search"
	toolNameWebSearch2025   = "web_search_20250305"

	webSearchDefaultMaxResults = 5                 // results requested from the search provider per query
	defaultWebSearchModel      = "claude-sonnet-4-6" // model name reported when the request did not specify one
	webSearchMsgIDPrefix       = "msg_ws_"         // prefix for synthesized Anthropic message IDs
	webSearchToolUseIDPrefix   = "srvtoolu_ws_"    // prefix for synthesized server_tool_use IDs
	tokenEstimateDivisor       = 4                 // rough chars-per-token divisor used to estimate output_tokens

	// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
	featureKeyWebSearchEmulation = "web_search_emulation"
)

// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
var webSearchManagerPtr atomic.Pointer[websearch.Manager]
// SetWebSearchManager wires the websearch.Manager into the gateway.
// Safe for concurrent use: the pointer is swapped atomically, so in-flight
// requests observe either the old manager or the new one, never a torn value.
func SetWebSearchManager(m *websearch.Manager) {
	webSearchManagerPtr.Store(m)
}

// getWebSearchManager returns the currently registered manager, or nil when
// web search emulation has not been wired up.
func getWebSearchManager() *websearch.Manager {
	return webSearchManagerPtr.Load()
}
// shouldEmulateWebSearch reports whether this request should be intercepted
// and answered locally instead of being forwarded upstream.
//
// All of the following must hold: a websearch.Manager is registered, the body
// declares exactly one web_search tool, and the global switch is on. The
// account-level mode then decides: "enabled" forces interception on,
// "disabled" forces it off, and any other value defers to the channel
// configuration for the request's group.
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
	if getWebSearchManager() == nil || !isOnlyWebSearchToolInBody(body) || !s.settingService.IsWebSearchEmulationEnabled(ctx) {
		return false
	}
	mode := account.GetWebSearchEmulationMode()
	if mode == WebSearchModeEnabled {
		return true
	}
	if mode == WebSearchModeDisabled {
		return false
	}
	// "default" → follow the channel configuration for the group.
	if groupID == nil || s.channelService == nil {
		return false
	}
	ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
	if err != nil || ch == nil {
		return false
	}
	return ch.IsWebSearchEmulationEnabled(account.Platform)
}
// isOnlyWebSearchToolInBody reports whether the request body declares exactly
// one tool and that tool is a web-search tool.
func isOnlyWebSearchToolInBody(body []byte) bool {
	tools := gjson.GetBytes(body, "tools")
	if !tools.IsArray() {
		return false
	}
	entries := tools.Array()
	return len(entries) == 1 && isWebSearchToolJSON(entries[0])
}
// isWebSearchToolJSON reports whether a single tool entry denotes a
// web-search tool, matching either by its "type" (any "web_search*" variant
// or "google_search") or, failing that, by its "name".
func isWebSearchToolJSON(tool gjson.Result) bool {
	if typ := tool.Get("type").String(); typ == toolTypeGoogleSearch || strings.HasPrefix(typ, toolTypeWebSearchPrefix) {
		return true
	}
	name := tool.Get("name").String()
	return name == toolNameWebSearch || name == toolNameGoogleSearch || name == toolNameWebSearch2025
}
// extractSearchQueryFromBody pulls the search query out of the request body:
// the text of the final message, provided that message came from the user.
// Returns "" when no suitable query exists.
func extractSearchQueryFromBody(body []byte) string {
	messages := gjson.GetBytes(body, "messages")
	if !messages.IsArray() {
		return ""
	}
	items := messages.Array()
	n := len(items)
	if n == 0 {
		return ""
	}
	if last := items[n-1]; last.Get("role").String() == "user" {
		return extractWebSearchTextFromContent(last.Get("content"))
	}
	return ""
}
// extractWebSearchTextFromContent extracts query text from a message
// "content" field, which may be a plain string or an array of content
// blocks; for arrays, the first "text" block with non-empty text wins.
func extractWebSearchTextFromContent(content gjson.Result) string {
	if content.Type == gjson.String {
		return content.String()
	}
	if !content.IsArray() {
		return ""
	}
	for _, blk := range content.Array() {
		if blk.Get("type").String() != "text" {
			continue
		}
		if txt := blk.Get("text").String(); txt != "" {
			return txt
		}
	}
	return ""
}
// handleWebSearchEmulation intercepts a web-search-only request,
// calls a third-party search API, and constructs an Anthropic-format response
// without contacting the upstream model provider at all.
//
// Errors: ErrProxyUnavailable from the search layer is translated into an
// UpstreamFailoverError (502) so the caller can retry on another account;
// any other search failure is returned unchanged.
func (s *GatewayService) handleWebSearchEmulation(
	ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
) (*ForwardResult, error) {
	startTime := time.Now()
	// Release the serial queue lock immediately — we don't need upstream.
	if parsed.OnUpstreamAccepted != nil {
		parsed.OnUpstreamAccepted()
	}
	query := extractSearchQueryFromBody(parsed.Body)
	if query == "" {
		return nil, fmt.Errorf("web search emulation: no query found in messages")
	}
	slog.Info("web search emulation: executing search",
		"account_id", account.ID, "account_name", account.Name, "query", query)
	resp, providerName, err := doWebSearch(ctx, account, query)
	if err != nil {
		// Proxy unavailable → trigger account switch via UpstreamFailoverError
		if errors.Is(err, websearch.ErrProxyUnavailable) {
			return nil, &UpstreamFailoverError{
				StatusCode:   http.StatusBadGateway,
				ResponseBody: []byte(err.Error()),
			}
		}
		return nil, err
	}
	slog.Info("web search emulation: search completed",
		"provider", providerName, "results_count", len(resp.Results))
	// Fall back to a stable default model name for the synthesized response
	// when the request did not specify one.
	model := parsed.Model
	if model == "" {
		model = defaultWebSearchModel
	}
	if parsed.Stream {
		return writeWebSearchStreamResponse(c, query, resp, model, startTime)
	}
	return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
}
// doWebSearch runs the query through the best available search provider,
// honoring the account's proxy when one is configured. It returns the search
// response and the name of the provider that served it.
func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
	mgr := getWebSearchManager()
	if mgr == nil {
		return nil, "", fmt.Errorf("web search emulation: manager not initialized")
	}
	req := websearch.SearchRequest{
		Query:      query,
		MaxResults: webSearchDefaultMaxResults,
		ProxyURL:   resolveAccountProxyURL(account),
	}
	resp, providerName, err := mgr.SearchWithBestProvider(ctx, req)
	if err != nil {
		slog.Error("web search emulation: search failed", "error", err)
		return nil, "", fmt.Errorf("web search emulation: %w", err)
	}
	return resp, providerName, nil
}
// resolveAccountProxyURL returns the proxy URL configured on the account,
// or "" when the account has no usable proxy attached.
func resolveAccountProxyURL(account *Account) string {
	if account.ProxyID == nil || account.Proxy == nil {
		return ""
	}
	return account.Proxy.URL()
}
// --- SSE streaming response ---

// writeWebSearchStreamResponse renders the search results as an Anthropic SSE
// stream: message_start, a server_tool_use block, the tool-result block, a
// text summary, and the closing message events. A failed write aborts the
// remaining events (logged once) but still flushes and reports success — the
// client simply receives a truncated stream.
func writeWebSearchStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	msgID := webSearchMsgIDPrefix + uuid.New().String()
	toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
	textSummary := buildTextSummary(query, resp.Results)

	setSSEHeaders(c)
	w := c.Writer

	err := writeSSEMessageStart(w, msgID, model)
	if err == nil {
		err = writeSSEServerToolUse(w, toolUseID, query, 0)
	}
	if err == nil {
		err = writeSSEToolResult(w, toolUseID, resp.Results, 1)
	}
	if err == nil {
		err = writeSSETextBlock(w, textSummary, 2)
	}
	if err == nil {
		err = writeSSEMessageEnd(w, len(textSummary)/tokenEstimateDivisor)
	}
	if err != nil {
		slog.Warn("web search emulation: SSE write failed, stopping", "error", err)
	}

	w.Flush()
	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
// setSSEHeaders prepares the response for server-sent events and commits a
// 200 status; X-Accel-Buffering: no disables reverse-proxy buffering.
func setSSEHeaders(c *gin.Context) {
	h := c.Writer.Header()
	h.Set("Content-Type", "text/event-stream")
	h.Set("Cache-Control", "no-cache")
	h.Set("Connection", "keep-alive")
	h.Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)
}
// writeSSEMessageStart emits the message_start event that opens the
// synthesized assistant message (empty content, zeroed usage counters).
func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
	message := map[string]any{
		"id": msgID, "type": "message", "role": "assistant", "model": model,
		"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
		"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
	}
	return flushSSEJSON(w, "message_start", map[string]any{
		"type":    "message_start",
		"message": message,
	})
}
// writeSSEServerToolUse emits a complete server_tool_use content block
// (content_block_start + content_block_stop) describing the emulated
// web_search invocation.
func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) error {
	block := map[string]any{
		"type": "server_tool_use", "id": toolUseID,
		"name": toolNameWebSearch, "input": map[string]string{"query": query},
	}
	if err := flushSSEJSON(w, "content_block_start", map[string]any{
		"type": "content_block_start", "index": index, "content_block": block,
	}); err != nil {
		return err
	}
	return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
// writeSSEToolResult emits a complete web_search_tool_result content block
// (content_block_start + content_block_stop) carrying the search results.
func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) error {
	block := map[string]any{
		"type": "web_search_tool_result", "tool_use_id": toolUseID,
		"content": buildSearchResultBlocks(results),
	}
	if err := flushSSEJSON(w, "content_block_start", map[string]any{
		"type": "content_block_start", "index": index, "content_block": block,
	}); err != nil {
		return err
	}
	return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
// writeSSETextBlock emits a text content block: start, a single delta
// carrying the whole text, then stop.
func writeSSETextBlock(w http.ResponseWriter, text string, index int) error {
	events := []struct {
		name    string
		payload map[string]any
	}{
		{"content_block_start", map[string]any{
			"type": "content_block_start", "index": index,
			"content_block": map[string]any{"type": "text", "text": ""},
		}},
		{"content_block_delta", map[string]any{
			"type": "content_block_delta", "index": index,
			"delta": map[string]string{"type": "text_delta", "text": text},
		}},
		{"content_block_stop", map[string]any{"type": "content_block_stop", "index": index}},
	}
	for _, ev := range events {
		if err := flushSSEJSON(w, ev.name, ev.payload); err != nil {
			return err
		}
	}
	return nil
}
// writeSSEMessageEnd emits message_delta (stop reason plus the estimated
// output token usage) followed by the terminating message_stop event.
func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) error {
	delta := map[string]any{
		"type":  "message_delta",
		"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
		"usage": map[string]int{"output_tokens": outputTokens},
	}
	if err := flushSSEJSON(w, "message_delta", delta); err != nil {
		return err
	}
	return flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
}
// flushSSEJSON marshals data to JSON and writes an SSE event.
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b); err != nil {
return fmt.Errorf("write: %w", err)
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
}
// --- Non-streaming JSON response ---

// writeWebSearchNonStreamResponse renders the search results as a single
// Anthropic-format JSON message containing three content blocks — a
// server_tool_use block, the web_search_tool_result block, and a text
// summary — with output_tokens estimated as len(summary)/4.
func writeWebSearchNonStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	msgID := webSearchMsgIDPrefix + uuid.New().String()
	toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
	textSummary := buildTextSummary(query, resp.Results)
	msg := map[string]any{
		"id": msgID, "type": "message", "role": "assistant", "model": model,
		"content": []any{
			map[string]any{
				"type": "server_tool_use", "id": toolUseID,
				"name": toolNameWebSearch, "input": map[string]string{"query": query},
			},
			map[string]any{
				"type": "web_search_tool_result", "tool_use_id": toolUseID,
				"content": buildSearchResultBlocks(resp.Results),
			},
			map[string]any{"type": "text", "text": textSummary},
		},
		"stop_reason": "end_turn", "stop_sequence": nil,
		"usage": map[string]int{"input_tokens": 0, "output_tokens": len(textSummary) / tokenEstimateDivisor},
	}
	body, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
	}
	c.Data(http.StatusOK, "application/json", body)
	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
// --- Helpers ---

// buildSearchResultBlocks converts provider results into Anthropic
// web_search_result content blocks; optional fields (snippet, page age) are
// included only when present.
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
	blocks := make([]map[string]string, 0, len(results))
	for i := range results {
		r := &results[i]
		blk := map[string]string{
			"type":  "web_search_result",
			"url":   r.URL,
			"title": r.Title,
		}
		if r.Snippet != "" {
			blk["page_content"] = r.Snippet
		}
		if r.PageAge != "" {
			blk["page_age"] = r.PageAge
		}
		blocks = append(blocks, blk)
	}
	return blocks
}
// buildTextSummary renders a human-readable markdown summary of the search
// results, or a "no results" line when the provider returned nothing.
func buildTextSummary(query string, results []websearch.SearchResult) string {
	if len(results) == 0 {
		return "No search results found for: " + query
	}
	var b strings.Builder
	fmt.Fprintf(&b, "Here are the search results for \"%s\":\n\n", query)
	for i := range results {
		r := results[i]
		fmt.Fprintf(&b, "%d. **%s**\n   %s\n   %s\n\n", i+1, r.Title, r.URL, r.Snippet)
	}
	return b.String()
}
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/stretchr/testify/require"
)
// --- isOnlyWebSearchToolInBody ---

// Positive cases: a single tool matched by "type" — exact, dated variant,
// or google_search.
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
}

// Positive cases: a single tool matched by "name" when no "type" is present.
func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
}

// Negative cases: interception requires exactly one tool, and it must be a
// web-search tool; "tools" must be a JSON array.
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody(
		[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
}

func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
}

func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
}
// --- extractSearchQueryFromBody ---

// The query is taken from the last message only, and only when its role is
// "user"; content may be a plain string or an array of content blocks.
func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"what is golang"}]}`
	require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
	require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
	require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
}

func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
}

// Array content: empty text blocks and non-text blocks are skipped; the
// first non-empty text block wins.
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
	require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
// --- buildSearchResultBlocks ---

// TestBuildSearchResultBlocks_WithResults: each result maps to a
// "web_search_result" block; optional fields (page_age, page_content) are
// only present when the source value is non-empty.
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
	results := []websearch.SearchResult{
		{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
		{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
	}
	blocks := buildSearchResultBlocks(results)
	require.Len(t, blocks, 2)
	require.Equal(t, "web_search_result", blocks[0]["type"])
	require.Equal(t, "https://a.com", blocks[0]["url"])
	require.Equal(t, "snippet a", blocks[0]["page_content"])
	require.Equal(t, "2 days", blocks[0]["page_age"])
	// Second result has no PageAge
	require.Equal(t, "https://b.com", blocks[1]["url"])
	_, hasPageAge := blocks[1]["page_age"]
	require.False(t, hasPageAge)
}

// TestBuildSearchResultBlocks_Empty: nil input yields an empty block list.
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
	blocks := buildSearchResultBlocks(nil)
	require.Empty(t, blocks)
}

// TestBuildSearchResultBlocks_SnippetEmpty: an empty snippet omits the page_content key.
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
	blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
	_, hasContent := blocks[0]["page_content"]
	require.False(t, hasContent)
}

// --- buildTextSummary ---

// TestBuildTextSummary_WithResults: the summary embeds the query and a
// numbered, bolded title plus URL for each result.
func TestBuildTextSummary_WithResults(t *testing.T) {
	results := []websearch.SearchResult{
		{URL: "https://a.com", Title: "A", Snippet: "desc a"},
	}
	summary := buildTextSummary("test query", results)
	require.Contains(t, summary, "test query")
	require.Contains(t, summary, "1. **A**")
	require.Contains(t, summary, "https://a.com")
}

// TestBuildTextSummary_NoResults: an empty result set produces the "no results" message.
func TestBuildTextSummary_NoResults(t *testing.T) {
	summary := buildTextSummary("test", nil)
	require.Contains(t, summary, "No search results found for: test")
}
// --- shouldEmulateWebSearch ---

// webSearchToolBody is a valid request body with exactly one web_search tool.
var webSearchToolBody = []byte(`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`)

// nonWebSearchToolBody is a request body without web_search tool.
var nonWebSearchToolBody = []byte(`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`)

// newAnthropicAPIKeyAccount creates a test Account with the given web search
// emulation mode stored under the account's Extra feature key.
func newAnthropicAPIKeyAccount(mode string) *Account {
	return &Account{
		ID:       1,
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: mode},
	}
}

// setGlobalWebSearchConfig stores a config in the global cache used by
// SettingService.IsWebSearchEmulationEnabled, with a 10-minute expiry so the
// cached value survives the duration of any test.
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
	webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
		config:    cfg,
		expiresAt: time.Now().Add(10 * time.Minute).UnixNano(),
	})
}

// clearGlobalWebSearchConfig resets the global cache to force re-read.
func clearGlobalWebSearchConfig() {
	webSearchEmulationCache.Store((*cachedWebSearchEmulationConfig)(nil))
}

// newSettingServiceForWebSearchTest creates a SettingService with a mock repo
// pre-loaded with a web-search emulation config (one brave provider).
func newSettingServiceForWebSearchTest(enabled bool) *SettingService {
	repo := newMockSettingRepo()
	cfg := &WebSearchEmulationConfig{
		Enabled:   enabled,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "sk-test"}},
	}
	data, _ := json.Marshal(cfg)
	repo.data[SettingKeyWebSearchEmulationConfig] = string(data)
	return NewSettingService(repo, &config.Config{})
}

// newChannelServiceWithCache creates a ChannelService with a pre-built cache
// containing the channel, keyed both by group ID and by channel ID, so no
// repository round-trip happens during the test.
func newChannelServiceWithCache(groupID int64, ch *Channel) *ChannelService {
	svc := &ChannelService{}
	cache := &channelCache{
		channelByGroupID: map[int64]*Channel{groupID: ch},
		byID:             map[int64]*Channel{ch.ID: ch},
		groupPlatform:    map[int64]string{},
		loadedAt:         time.Now(),
	}
	svc.cache.Store(cache)
	return svc
}
// TestShouldEmulateWebSearch_NilManager: with no global web search manager
// registered, emulation is refused even when config and account both enable it.
func TestShouldEmulateWebSearch_NilManager(t *testing.T) {
	SetWebSearchManager(nil)
	defer SetWebSearchManager(nil)
	settingSvc := newSettingServiceForWebSearchTest(true)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	svc := &GatewayService{settingService: settingSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
	require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}

// TestShouldEmulateWebSearch_NotOnlyWebSearchTool: a body whose tools are not
// exclusively web_search is never emulated.
func TestShouldEmulateWebSearch_NotOnlyWebSearchTool(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	settingSvc := newSettingServiceForWebSearchTest(true)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	svc := &GatewayService{settingService: settingSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
	require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, nonWebSearchToolBody))
}

// TestShouldEmulateWebSearch_GlobalDisabled: the global config switch overrides
// a per-account "enabled" mode.
func TestShouldEmulateWebSearch_GlobalDisabled(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	// Global config disabled
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   false,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	settingSvc := newSettingServiceForWebSearchTest(false)
	svc := &GatewayService{settingService: settingSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
	require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}

// TestShouldEmulateWebSearch_AccountDisabled: per-account "disabled" mode wins
// even when the global config is enabled.
func TestShouldEmulateWebSearch_AccountDisabled(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	settingSvc := newSettingServiceForWebSearchTest(true)
	svc := &GatewayService{settingService: settingSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeDisabled)
	require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}

// TestShouldEmulateWebSearch_AccountEnabled: the happy path — manager set,
// global config on, account mode enabled, body has only the web_search tool.
func TestShouldEmulateWebSearch_AccountEnabled(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	settingSvc := newSettingServiceForWebSearchTest(true)
	svc := &GatewayService{settingService: settingSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
	require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}

// TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled: in "default" account
// mode the decision falls through to the channel's per-platform feature flag.
func TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	settingSvc := newSettingServiceForWebSearchTest(true)
	ch := &Channel{
		ID:     10,
		Status: StatusActive,
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: true},
		},
	}
	channelSvc := newChannelServiceWithCache(42, ch)
	svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
	groupID := int64(42)
	require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}

// TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled: a channel flag of
// false blocks emulation for "default" account mode.
func TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	settingSvc := newSettingServiceForWebSearchTest(true)
	ch := &Channel{
		ID:     10,
		Status: StatusActive,
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: false},
		},
	}
	channelSvc := newChannelServiceWithCache(42, ch)
	svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
	groupID := int64(42)
	require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}

// TestShouldEmulateWebSearch_DefaultMode_NilGroupID: without a group ID there
// is no channel to consult, so "default" mode resolves to false.
func TestShouldEmulateWebSearch_DefaultMode_NilGroupID(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	settingSvc := newSettingServiceForWebSearchTest(true)
	svc := &GatewayService{settingService: settingSvc}
	account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
	// nil groupID + default mode → falls through to channel check → returns false
	require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}

// TestShouldEmulateWebSearch_DefaultMode_NilChannelService: a missing channel
// service also resolves "default" mode to false instead of panicking.
func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
	mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
	SetWebSearchManager(mgr)
	defer SetWebSearchManager(nil)
	setGlobalWebSearchConfig(&WebSearchEmulationConfig{
		Enabled:   true,
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
	})
	defer clearGlobalWebSearchConfig()
	settingSvc := newSettingServiceForWebSearchTest(true)
	svc := &GatewayService{settingService: settingSvc, channelService: nil}
	account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
	groupID := int64(42)
	// nil channelService + default mode → returns false
	require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
package service
import (
"encoding/json"
"strings"
)
// NotifyEmailEntry represents a notification email with enable/disable and
// verification state. All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
	Email    string `json:"email"`
	Disabled bool   `json:"disabled"`
	Verified bool   `json:"verified"`
}

// ParseNotifyEmails parses a JSON string into []NotifyEmailEntry.
// It auto-detects the format:
//   - Old format ["email1","email2"] → converted to
//     [{email, disabled:false, verified:false}, ...]; legacy emails default
//     to unverified so they must be confirmed under the new scheme.
//   - New format [{email,disabled,verified}, ...] → parsed directly.
//
// Returns nil on empty/invalid input.
func ParseNotifyEmails(raw string) []NotifyEmailEntry {
	raw = strings.TrimSpace(raw)
	if raw == "" || raw == "[]" {
		return nil
	}

	// Decide the format up front. json.Unmarshal of ["string"] into
	// []NotifyEmailEntry does not fail cleanly (each string becomes a zero
	// struct), so the object form may only be trusted when the payload is
	// NOT a string array.
	if !isOldStringArrayFormat(raw) {
		var entries []NotifyEmailEntry
		if err := json.Unmarshal([]byte(raw), &entries); err == nil && len(entries) > 0 {
			return entries
		}
	}

	// Old format (array of strings). Inputs like "null" also land here and
	// decode to an empty slice, producing an empty (non-nil) result.
	var emails []string
	if err := json.Unmarshal([]byte(raw), &emails); err != nil {
		return nil
	}
	result := make([]NotifyEmailEntry, 0, len(emails))
	for _, e := range emails {
		e = strings.TrimSpace(e)
		if e == "" {
			continue // drop blank entries
		}
		result = append(result, NotifyEmailEntry{
			Email:    e,
			Disabled: false,
			Verified: false, // old-format emails default to unverified
		})
	}
	return result
}

// isOldStringArrayFormat reports whether raw is a JSON string array like
// ["email1","email2"], as opposed to the new array-of-objects format.
func isOldStringArrayFormat(raw string) bool {
	var arr []json.RawMessage
	if err := json.Unmarshal([]byte(raw), &arr); err != nil || len(arr) == 0 {
		return false
	}
	// A string element starts with a quote; an object starts with '{'.
	first := strings.TrimSpace(string(arr[0]))
	return len(first) > 0 && first[0] == '"'
}

// MarshalNotifyEmails serializes []NotifyEmailEntry to a JSON string.
// Nil/empty input and marshal failures both yield "[]" so the stored value
// is always a valid JSON array.
func MarshalNotifyEmails(entries []NotifyEmailEntry) string {
	if len(entries) == 0 {
		return "[]"
	}
	data, err := json.Marshal(entries)
	if err != nil {
		return "[]"
	}
	return string(data)
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ---------- ParseNotifyEmails ----------

// Empty input must yield nil, not an empty slice.
func TestParseNotifyEmails_EmptyString(t *testing.T) {
	result := ParseNotifyEmails("")
	require.Nil(t, result)
}

func TestParseNotifyEmails_EmptyArray(t *testing.T) {
	result := ParseNotifyEmails("[]")
	require.Nil(t, result)
}

func TestParseNotifyEmails_Null(t *testing.T) {
	// "null" is valid JSON that unmarshals into a nil string slice.
	// The old-format branch then returns an empty (non-nil) slice.
	result := ParseNotifyEmails("null")
	require.Empty(t, result)
}

func TestParseNotifyEmails_WhitespaceOnly(t *testing.T) {
	result := ParseNotifyEmails("   ")
	require.Nil(t, result)
}

// Old format: plain string array converts with Verified=false, Disabled=false.
func TestParseNotifyEmails_OldFormat(t *testing.T) {
	raw := `["alice@example.com", "bob@example.com"]`
	result := ParseNotifyEmails(raw)
	require.Len(t, result, 2)
	require.Equal(t, "alice@example.com", result[0].Email)
	require.False(t, result[0].Verified, "old format emails should default to unverified")
	require.False(t, result[0].Disabled)
	require.Equal(t, "bob@example.com", result[1].Email)
	require.False(t, result[1].Verified)
	require.False(t, result[1].Disabled)
}

// Blank or whitespace-only entries in the old format are silently dropped.
func TestParseNotifyEmails_OldFormat_SkipsEmptyEntries(t *testing.T) {
	raw := `["alice@example.com", "", "   ", "bob@example.com"]`
	result := ParseNotifyEmails(raw)
	require.Len(t, result, 2)
	require.Equal(t, "alice@example.com", result[0].Email)
	require.Equal(t, "bob@example.com", result[1].Email)
}

// New format: object array is parsed directly, preserving flags.
func TestParseNotifyEmails_NewFormat(t *testing.T) {
	raw := `[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]`
	result := ParseNotifyEmails(raw)
	require.Len(t, result, 2)
	require.Equal(t, "alice@example.com", result[0].Email)
	require.True(t, result[0].Verified)
	require.False(t, result[0].Disabled)
	require.Equal(t, "bob@example.com", result[1].Email)
	require.False(t, result[1].Verified)
	require.True(t, result[1].Disabled)
}

func TestParseNotifyEmails_NewFormat_SingleEntry(t *testing.T) {
	raw := `[{"email":"solo@example.com","verified":true,"disabled":false}]`
	result := ParseNotifyEmails(raw)
	require.Len(t, result, 1)
	require.Equal(t, "solo@example.com", result[0].Email)
	require.True(t, result[0].Verified)
}

func TestParseNotifyEmails_InvalidJSON(t *testing.T) {
	result := ParseNotifyEmails(`{not valid json`)
	require.Nil(t, result)
}

func TestParseNotifyEmails_InvalidJSONObject(t *testing.T) {
	// A plain JSON object (not array) should return nil.
	result := ParseNotifyEmails(`{"email":"a@b.com"}`)
	require.Nil(t, result)
}

// Leading/trailing whitespace around the JSON document itself is tolerated.
func TestParseNotifyEmails_WhitespacePadding(t *testing.T) {
	raw := `  ["padded@example.com"]  `
	result := ParseNotifyEmails(raw)
	require.Len(t, result, 1)
	require.Equal(t, "padded@example.com", result[0].Email)
}

// ---------- MarshalNotifyEmails ----------

// Empty and nil slices both serialize to the literal "[]".
func TestMarshalNotifyEmails_EmptySlice(t *testing.T) {
	result := MarshalNotifyEmails([]NotifyEmailEntry{})
	require.Equal(t, "[]", result)
}

func TestMarshalNotifyEmails_NilSlice(t *testing.T) {
	result := MarshalNotifyEmails(nil)
	require.Equal(t, "[]", result)
}

func TestMarshalNotifyEmails_SingleEntry(t *testing.T) {
	entries := []NotifyEmailEntry{
		{Email: "test@example.com", Verified: true, Disabled: false},
	}
	result := MarshalNotifyEmails(entries)
	require.Contains(t, result, `"email":"test@example.com"`)
	require.Contains(t, result, `"verified":true`)
	require.Contains(t, result, `"disabled":false`)
	// Round-trip: parsing the marshalled result should produce the original entries.
	parsed := ParseNotifyEmails(result)
	require.Len(t, parsed, 1)
	require.Equal(t, entries[0], parsed[0])
}

func TestMarshalNotifyEmails_MultipleEntries(t *testing.T) {
	entries := []NotifyEmailEntry{
		{Email: "a@example.com", Verified: true, Disabled: false},
		{Email: "b@example.com", Verified: false, Disabled: true},
	}
	result := MarshalNotifyEmails(entries)
	// Round-trip verification.
	parsed := ParseNotifyEmails(result)
	require.Len(t, parsed, 2)
	require.Equal(t, entries[0], parsed[0])
	require.Equal(t, entries[1], parsed[1])
}

func TestMarshalNotifyEmails_RoundTrip_NewFormat(t *testing.T) {
	original := []NotifyEmailEntry{
		{Email: "x@example.com", Verified: true, Disabled: true},
		{Email: "y@example.com", Verified: false, Disabled: false},
	}
	marshalled := MarshalNotifyEmails(original)
	parsed := ParseNotifyEmails(marshalled)
	require.Equal(t, original, parsed)
}

// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ----------

func TestParseNotifyEmails_MixedOldFormatWithWhitespace(t *testing.T) {
	// Emails with leading/trailing whitespace in old format should be trimmed.
	raw := `["  alice@example.com  "]`
	result := ParseNotifyEmails(raw)
	require.Len(t, result, 1)
	require.Equal(t, "alice@example.com", result[0].Email)
}
......@@ -147,6 +147,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
......
......@@ -327,6 +327,7 @@ type OpenAIGatewayService struct {
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
......@@ -364,6 +365,7 @@ func NewOpenAIGatewayService(
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
......@@ -393,6 +395,7 @@ func NewOpenAIGatewayService(
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
balanceNotifyService: balanceNotifyService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
......@@ -477,11 +480,12 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
}
}
......@@ -1677,7 +1681,6 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if err != nil || latest == nil {
return nil
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now())
if !latest.IsSchedulable() || !latest.IsOpenAI() {
return nil
}
......@@ -1700,7 +1703,6 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun
if err != nil || account == nil {
return account, err
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now())
return account, nil
}
......@@ -4569,6 +4571,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
tokens, cost.TotalCost,
)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
......@@ -4756,69 +4766,6 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow
return updates
}
func codexUsagePercentExhausted(value *float64) bool {
return value != nil && *value >= 100-1e-9
}
func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time {
if snapshot == nil {
return nil
}
normalized := snapshot.Normalize()
if normalized == nil {
return nil
}
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil {
resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
return &resetAt
}
if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil {
resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
return &resetAt
}
return nil
}
func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time {
if len(extra) == 0 {
return nil
}
if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
resetAt := progress.ResetsAt.UTC()
return &resetAt
}
if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
resetAt := progress.ResetsAt.UTC()
return &resetAt
}
return nil
}
func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) {
if account == nil || !account.IsOpenAI() {
return nil, false
}
resetAt := codexRateLimitResetAtFromExtra(account.Extra, now)
if resetAt == nil {
return nil, false
}
if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) {
return account.RateLimitResetAt, false
}
account.RateLimitResetAt = resetAt
return resetAt, true
}
func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time {
resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now)
if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 {
return resetAt
}
_ = repo.SetRateLimited(ctx, account.ID, *resetAt)
return resetAt
}
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
if snapshot == nil {
......@@ -4830,24 +4777,17 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
now := time.Now()
updates := buildCodexUsageExtraUpdates(snapshot, now)
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now)
if len(updates) == 0 && resetAt == nil {
if len(updates) == 0 {
return
}
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
if !shouldPersistUpdates && resetAt == nil {
if !s.getCodexSnapshotThrottle().Allow(accountID, now) {
return
}
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if shouldPersistUpdates {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}
if resetAt != nil {
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
}
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}
......
......@@ -413,7 +413,12 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
// After normal client close, the server goroutine may receive the close frame
// as an error — this is expected behavior, not a test failure.
if serverErr != nil {
require.Contains(t, serverErr.Error(), "StatusNormalClosure",
"server error should only be a normal close frame, got: %v", serverErr)
}
case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时")
}
......
......@@ -617,6 +617,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
......@@ -345,7 +345,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
}
}
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) {
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) {
repo := &openAICodexSnapshotAsyncRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
......@@ -359,7 +359,6 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRate
SecondaryResetAfterSeconds: ptrIntWS(1200),
SecondaryWindowMinutes: ptrIntWS(300),
}
before := time.Now()
svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot)
select {
......@@ -371,9 +370,8 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRate
select {
case resetAt := <-repo.rateLimitCh:
require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second)
t.Fatalf("不应因仅写入快照而生成运行时限流时间: %v", resetAt)
case <-time.After(2 * time.Second):
t.Fatal("等待 codex 100% 自动切换限流超时")
}
}
......@@ -401,7 +399,7 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
select {
case resetAt := <-repo.rateLimitCh:
t.Fatalf("unexpected rate limit reset at: %v", resetAt)
t.Fatalf("不应写入运行时限流时间: %v", resetAt)
case <-time.After(200 * time.Millisecond):
}
}
......@@ -409,7 +407,6 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) {
repo := &openAICodexSnapshotAsyncRepo{
updateExtraCh: make(chan map[string]any, 2),
rateLimitCh: make(chan time.Time, 2),
}
svc := &OpenAIGatewayService{
accountRepo: repo,
......@@ -443,7 +440,7 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *t
func ptrFloat64WS(v float64) *float64 { return &v }
func ptrIntWS(v int) *int { return &v }
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) {
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraDoesNotSetRateLimit(t *testing.T) {
resetAt := time.Now().Add(6 * 24 * time.Hour)
account := Account{
ID: 701,
......@@ -463,17 +460,15 @@ func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateL
fresh, err := svc.getSchedulableAccount(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, fresh)
require.NotNil(t, fresh.RateLimitResetAt)
require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second)
require.Nil(t, fresh.RateLimitResetAt)
select {
case persisted := <-repo.rateLimitCh:
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
t.Fatalf("不应将已耗尽的 codex extra 提升为运行时限流状态: %v", persisted)
case <-time.After(2 * time.Second):
t.Fatal("等待旧快照补写限流状态超时")
}
}
func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) {
func TestAdminService_ListAccounts_ExhaustedCodexExtraDoesNotSetRateLimit(t *testing.T) {
resetAt := time.Now().Add(4 * 24 * time.Hour)
repo := &openAICodexExtraListRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{
......@@ -496,13 +491,11 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Len(t, accounts, 1)
require.NotNil(t, accounts[0].RateLimitResetAt)
require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second)
require.Nil(t, accounts[0].RateLimitResetAt)
select {
case persisted := <-repo.rateLimitCh:
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
t.Fatalf("不应在账号列表查询时将 codex extra 持久化为运行时限流状态: %v", persisted)
case <-time.After(2 * time.Second):
t.Fatal("等待列表补写限流状态超时")
}
}
......
......@@ -64,12 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
if acc.ID <= 0 {
continue
}
c := acc.Concurrency
if c <= 0 {
c = 1
}
if prev, ok := unique[acc.ID]; !ok || c > prev {
unique[acc.ID] = c
lf := acc.EffectiveLoadFactor()
if prev, ok := unique[acc.ID]; !ok || lf > prev {
unique[acc.ID] = lf
}
}
......
......@@ -391,7 +391,7 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
}
batch = append(batch, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
if len(batch) == 0 {
......
......@@ -183,6 +183,15 @@ func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) {
if strings.TrimSpace(item.Message) == "" {
t.Fatalf("message should not be empty")
}
// writtenCount is incremented after BatchInsertSystemLogsFn returns,
// so poll briefly to avoid a race between the done signal and the atomic add.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if sink.Health().WrittenCount > 0 {
break
}
time.Sleep(time.Millisecond)
}
health := sink.Health()
if health.WrittenCount == 0 {
t.Fatalf("written_count should be >0")
......
package service
import (
"math"
"github.com/shopspring/decimal"
)
// defaultBalanceRechargeMultiplier is the fallback applied whenever a
// configured recharge multiplier is unusable.
const defaultBalanceRechargeMultiplier = 1.0

// normalizeBalanceRechargeMultiplier sanitizes a recharge multiplier:
// NaN, ±Inf, and non-positive values all fall back to the default (1.0);
// any other value passes through unchanged.
func normalizeBalanceRechargeMultiplier(multiplier float64) float64 {
	switch {
	case math.IsNaN(multiplier), math.IsInf(multiplier, 0), multiplier <= 0:
		return defaultBalanceRechargeMultiplier
	default:
		return multiplier
	}
}
func calculateCreditedBalance(paymentAmount, multiplier float64) float64 {
return decimal.NewFromFloat(paymentAmount).
Mul(decimal.NewFromFloat(normalizeBalanceRechargeMultiplier(multiplier))).
Round(2).
InexactFloat64()
}
// calculateGatewayRefundAmount converts a refund expressed against the order
// amount into the amount to refund through the payment gateway, pro-rated
// against what was actually paid (payAmount and orderAmount can differ).
//
// Returns 0 when any input is non-positive. A refund equal to the order
// amount (within amountToleranceCNY) refunds the full payAmount; otherwise
// the refund is payAmount * refundAmount / orderAmount, rounded to 2 decimals
// using decimal arithmetic to avoid float drift.
func calculateGatewayRefundAmount(orderAmount, payAmount, refundAmount float64) float64 {
	if orderAmount <= 0 || payAmount <= 0 || refundAmount <= 0 {
		return 0
	}
	// Full refund: within tolerance of the whole order, return exactly what
	// was paid so no rounding residue is left behind.
	if math.Abs(refundAmount-orderAmount) <= amountToleranceCNY {
		return decimal.NewFromFloat(payAmount).Round(2).InexactFloat64()
	}
	// Partial refund: pro-rate the paid amount by refundAmount/orderAmount.
	return decimal.NewFromFloat(payAmount).
		Mul(decimal.NewFromFloat(refundAmount)).
		Div(decimal.NewFromFloat(orderAmount)).
		Round(2).
		InexactFloat64()
}
......@@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
......@@ -10,6 +11,52 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// validatePlanRequired checks that all required fields for a plan are
// provided. Fields are validated in a fixed order (name, group, price,
// validity days, validity unit, original price) and the first violation
// is returned as a BadRequest error; nil means the plan is valid.
func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string, originalPrice *float64) error {
	switch {
	case strings.TrimSpace(name) == "":
		return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required")
	case groupID <= 0:
		return infraerrors.BadRequest("PLAN_GROUP_REQUIRED", "group is required")
	case price <= 0:
		return infraerrors.BadRequest("PLAN_PRICE_INVALID", "price must be > 0")
	case validityDays <= 0:
		return infraerrors.BadRequest("PLAN_VALIDITY_REQUIRED", "validity days must be > 0")
	case strings.TrimSpace(validityUnit) == "":
		return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
	case originalPrice != nil && *originalPrice < 0:
		return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
	}
	return nil
}
// validatePlanPatch validates only the non-nil fields in a patch update.
// Each present field must satisfy the same constraint enforced on create;
// absent (nil) fields are skipped. Returns the first violation as a
// BadRequest error, or nil.
func validatePlanPatch(req UpdatePlanRequest) error {
	switch {
	case req.Name != nil && strings.TrimSpace(*req.Name) == "":
		return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required")
	case req.GroupID != nil && *req.GroupID <= 0:
		return infraerrors.BadRequest("PLAN_GROUP_REQUIRED", "group is required")
	case req.Price != nil && *req.Price <= 0:
		return infraerrors.BadRequest("PLAN_PRICE_INVALID", "price must be > 0")
	case req.ValidityDays != nil && *req.ValidityDays <= 0:
		return infraerrors.BadRequest("PLAN_VALIDITY_REQUIRED", "validity days must be > 0")
	case req.ValidityUnit != nil && strings.TrimSpace(*req.ValidityUnit) == "":
		return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
	case req.OriginalPrice != nil && *req.OriginalPrice < 0:
		return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
	}
	return nil
}
// --- Plan CRUD ---
// PlanGroupInfo holds the group details needed for subscription plan display.
......@@ -74,6 +121,9 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S
}
func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit, req.OriginalPrice); err != nil {
return nil, err
}
b := s.entClient.SubscriptionPlan.Create().
SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
......@@ -86,8 +136,12 @@ func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanReq
}
// UpdatePlan updates a subscription plan by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate.
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate
// plus a validation guard for non-nil fields.
func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
if err := validatePlanPatch(req); err != nil {
return nil, err
}
u := s.entClient.SubscriptionPlan.UpdateOneID(id)
if req.GroupID != nil {
u.SetGroupID(*req.GroupID)
......
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestValidatePlanRequired exercises validatePlanRequired with a table of
// scenarios. wantErr is a substring the returned error must contain; the
// empty string means the input is expected to be valid.
// (Consolidates the former sixteen one-assert tests into the idiomatic
// table-driven form; coverage is unchanged.)
func TestValidatePlanRequired(t *testing.T) {
	neg := -10.0
	zero := 0.0
	op := 19.99
	tests := []struct {
		name          string
		planName      string
		groupID       int64
		price         float64
		validityDays  int
		validityUnit  string
		originalPrice *float64
		wantErr       string // "" => expect success
	}{
		{"all valid", "Pro", 1, 9.99, 30, "days", nil, ""},
		{"empty name", "", 1, 9.99, 30, "days", nil, "plan name"},
		{"whitespace name", "   ", 1, 9.99, 30, "days", nil, "plan name"},
		{"zero group id", "Pro", 0, 9.99, 30, "days", nil, "group"},
		{"negative group id", "Pro", -1, 9.99, 30, "days", nil, "group"},
		{"zero price", "Pro", 1, 0, 30, "days", nil, "price"},
		{"negative price", "Pro", 1, -5, 30, "days", nil, "price"},
		{"zero validity days", "Pro", 1, 9.99, 0, "days", nil, "validity days"},
		{"negative validity days", "Pro", 1, 9.99, -7, "days", nil, "validity days"},
		{"empty validity unit", "Pro", 1, 9.99, 30, "", nil, "validity unit"},
		{"whitespace validity unit", "Pro", 1, 9.99, 30, "   ", nil, "validity unit"},
		{"name validated first", "", 0, 0, 0, "", nil, "plan name"},
		{"trimmed valid name", "  Pro  ", 1, 9.99, 30, "days", nil, ""},
		{"negative original price", "Pro", 1, 9.99, 30, "days", &neg, "original price"},
		{"zero original price", "Pro", 1, 9.99, 30, "days", &zero, ""},
		{"valid original price", "Pro", 1, 9.99, 30, "days", &op, ""},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := validatePlanRequired(tc.planName, tc.groupID, tc.price, tc.validityDays, tc.validityUnit, tc.originalPrice)
			if tc.wantErr == "" {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// --- validatePlanPatch tests ---

// TestValidatePlanPatchOriginalPrice covers the OriginalPrice patch field:
// nil is skipped, zero and positive values pass, negatives are rejected.
// (Table-driven consolidation of the former four one-assert tests.)
func TestValidatePlanPatchOriginalPrice(t *testing.T) {
	neg := -5.0
	zero := 0.0
	op := 29.99
	tests := []struct {
		name          string
		originalPrice *float64
		wantErr       string // "" => expect success
	}{
		{"negative", &neg, "original price"},
		{"zero", &zero, ""},
		{"valid", &op, ""},
		{"nil", nil, ""},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: tc.originalPrice})
			if tc.wantErr == "" {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// --- validatePlanPatch: other fields ---

// Pointer helpers for building UpdatePlanRequest literals in tests
// (Go has no address-of for literals, so these wrap a parameter copy).
func ptrStr(s string) *string { return &s }
func ptrInt(i int) *int { return &i }
func ptrInt64(i int64) *int64 { return &i }
func ptrFloat(f float64) *float64 { return &f }
// TestValidatePlanPatchFields covers the remaining patch fields (name, group
// id, price, validity days/unit): nil fields are skipped, while present
// fields must satisfy the same constraints as on create. wantErr is a
// substring the error must contain; "" means no error is expected.
// (Table-driven consolidation of the former ten one-assert tests.)
func TestValidatePlanPatchFields(t *testing.T) {
	tests := []struct {
		name    string
		req     UpdatePlanRequest
		wantErr string // "" => expect success
	}{
		{"empty name", UpdatePlanRequest{Name: ptrStr("")}, "plan name"},
		{"valid name", UpdatePlanRequest{Name: ptrStr("Basic")}, ""},
		{"zero group id", UpdatePlanRequest{GroupID: ptrInt64(0)}, "group"},
		{"negative price", UpdatePlanRequest{Price: ptrFloat(-1)}, "price"},
		{"zero price", UpdatePlanRequest{Price: ptrFloat(0)}, "price"},
		{"valid price", UpdatePlanRequest{Price: ptrFloat(9.99)}, ""},
		{"zero validity days", UpdatePlanRequest{ValidityDays: ptrInt(0)}, "validity days"},
		{"empty validity unit", UpdatePlanRequest{ValidityUnit: ptrStr("")}, "validity unit"},
		{"valid validity unit", UpdatePlanRequest{ValidityUnit: ptrStr("days")}, ""},
		{"all nil", UpdatePlanRequest{}, ""},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := validatePlanPatch(tc.req)
			if tc.wantErr == "" {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
......@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db
// ProviderInstanceResponse is the API response for a provider instance.
type ProviderInstanceResponse struct {
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
}
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
......@@ -46,8 +47,9 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder,
PaymentMode: inst.PaymentMode,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
if err != nil {
......@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err != nil {
return nil, err
}
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
SetAllowUserRefund(allowUserRefund).
Save(ctx)
}
......@@ -221,6 +225,29 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled)
// Cascade: turning off refund_enabled also disables allow_user_refund
if !*req.RefundEnabled {
u.SetAllowUserRefund(false)
}
}
if req.AllowUserRefund != nil {
// Only allow enabling when refund_enabled is (or will be) true
if *req.AllowUserRefund {
refundEnabled := false
if req.RefundEnabled != nil {
refundEnabled = *req.RefundEnabled
} else {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil {
refundEnabled = inst.RefundEnabled
}
}
if refundEnabled {
u.SetAllowUserRefund(true)
}
} else {
u.SetAllowUserRefund(false)
}
}
if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode)
......@@ -228,6 +255,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
return u.Save(ctx)
}
// GetUserRefundEligibleInstanceIDs returns the IDs (as decimal strings) of
// payment provider instances that allow user-initiated refunds, i.e. rows
// where both refund_enabled and allow_user_refund are true. Only the ID
// column is selected from the database.
func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) {
	rows, err := s.entClient.PaymentProviderInstance.Query().
		Where(
			paymentproviderinstance.RefundEnabledEQ(true),
			paymentproviderinstance.AllowUserRefundEQ(true),
		).
		Select(paymentproviderinstance.FieldID).
		All(ctx)
	if err != nil {
		return nil, err
	}
	result := make([]string, 0, len(rows))
	for _, row := range rows {
		result = append(result, strconv.FormatInt(int64(row.ID), 10))
	}
	return result, nil
}
func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
......
......@@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
field string
field string
wantSen bool
}{
// Sensitive fields (contain key/secret/private/password/pkey patterns)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment