Unverified Commit d402e722 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1637 from touwaeriol/feat/websearch-notify-pricing

feat: web search emulation, balance/quota notify, account stats pricing, per-provider refund control, Stripe fix / Web 搜索模拟、余额配额通知、渠道统计计费、按服务商退款控制、Stripe 修复
parents e534e9ba 8548a130
package service
import (
"context"
"fmt"
"html"
"log/slog"
"strconv"
"strings"
"time"
)
const (
emailSendTimeout = 30 * time.Second
// Threshold type values
thresholdTypeFixed = "fixed"
thresholdTypePercentage = "percentage"
// Quota dimension labels
quotaDimDaily = "daily"
quotaDimWeekly = "weekly"
quotaDimTotal = "total"
defaultSiteName = "Sub2API"
)
// quotaDimLabels maps dimension names to display labels.
var quotaDimLabels = map[string]string{
quotaDimDaily: "日限额 / Daily",
quotaDimWeekly: "周限额 / Weekly",
quotaDimTotal: "总限额 / Total",
}
// AccountQuotaReader provides read access to account quota data.
type AccountQuotaReader interface {
GetByID(ctx context.Context, id int64) (*Account, error)
}
// BalanceNotifyService handles balance and quota threshold notifications.
type BalanceNotifyService struct {
emailService *EmailService
settingRepo SettingRepository
accountRepo AccountQuotaReader
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService {
return &BalanceNotifyService{
emailService: emailService,
settingRepo: settingRepo,
accountRepo: accountRepo,
}
}
// resolveBalanceThreshold returns the effective balance threshold.
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
if thresholdType == thresholdTypePercentage && totalRecharged > 0 {
return totalRecharged * threshold / 100
}
return threshold
}
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, user *User, oldBalance, cost float64) {
if !s.canNotifyBalance(user) {
return
}
effectiveThreshold, rechargeURL, ok := s.resolveUserEffectiveThreshold(ctx, user)
if !ok {
return
}
newBalance := oldBalance - cost
if !crossedDownward(oldBalance, newBalance, effectiveThreshold) {
return
}
s.dispatchBalanceLowEmail(ctx, user, newBalance, effectiveThreshold, rechargeURL)
}
// canNotifyBalance checks nil guards and user-level toggle.
func (s *BalanceNotifyService) canNotifyBalance(user *User) bool {
if user == nil || s.emailService == nil || s.settingRepo == nil {
return false
}
return user.BalanceNotifyEnabled
}
// resolveUserEffectiveThreshold reads global + user config, returns the effective threshold.
// Returns ok=false when notifications should be skipped.
func (s *BalanceNotifyService) resolveUserEffectiveThreshold(ctx context.Context, user *User) (effectiveThreshold float64, rechargeURL string, ok bool) {
globalEnabled, globalThreshold, rechargeURL := s.getBalanceNotifyConfig(ctx)
if !globalEnabled {
return 0, "", false
}
threshold := globalThreshold
if user.BalanceNotifyThreshold != nil {
threshold = *user.BalanceNotifyThreshold
}
if threshold <= 0 {
return 0, "", false
}
effectiveThreshold = resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged)
if effectiveThreshold <= 0 {
return 0, "", false
}
return effectiveThreshold, rechargeURL, true
}
// crossedDownward returns true when oldV was at-or-above threshold but newV dropped below it.
func crossedDownward(oldV, newV, threshold float64) bool {
return oldV >= threshold && newV < threshold
}
// dispatchBalanceLowEmail collects recipients and sends the alert in a goroutine.
func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user *User, newBalance, threshold float64, rechargeURL string) {
siteName := s.getSiteName(ctx)
recipients := s.collectBalanceNotifyRecipients(user)
slog.Info("CheckBalanceAfterDeduction: sending notification",
"user_id", user.ID, "recipients", recipients, "new_balance", newBalance, "threshold", threshold)
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in balance notification", "recover", r)
}
}()
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
}()
}
// quotaDim describes one quota dimension for notification checking.
type quotaDim struct {
name string
enabled bool
threshold float64
thresholdType string // "fixed" (default) or "percentage"
currentUsed float64
limit float64
}
// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point.
// The threshold represents how much quota REMAINS when the alert fires:
// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400)
// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%)
func (d quotaDim) resolvedThreshold() float64 {
if d.limit <= 0 {
return 0
}
if d.thresholdType == thresholdTypePercentage {
return d.limit * (1 - d.threshold/100)
}
return d.limit - d.threshold
}
// buildQuotaDims returns the three quota dimensions for notification checking.
func buildQuotaDims(account *Account) []quotaDim {
return []quotaDim{
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()},
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()},
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), account.GetQuotaUsed(), account.GetQuotaLimit()},
}
}
// buildQuotaDimsFromState builds quota dimensions using DB transaction state instead of account snapshot.
// Notification settings (enabled, threshold, thresholdType) come from the account; usage values from quotaState.
func buildQuotaDimsFromState(account *Account, state *AccountQuotaState) []quotaDim {
return []quotaDim{
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), state.DailyUsed, state.DailyLimit},
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), state.WeeklyUsed, state.WeeklyLimit},
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), state.TotalUsed, state.TotalLimit},
}
}
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
// When quotaState is non-nil (from DB transaction RETURNING), it is used directly for threshold
// checking, avoiding a separate DB read. Otherwise it falls back to fetching fresh account data.
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64, quotaState *AccountQuotaState) {
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
return
}
if !s.isAccountQuotaNotifyEnabled(ctx) {
return
}
adminEmails := s.getAccountQuotaNotifyEmails(ctx)
if len(adminEmails) == 0 {
return
}
siteName := s.getSiteName(ctx)
var dims []quotaDim
if quotaState != nil {
dims = buildQuotaDimsFromState(account, quotaState)
} else {
freshAccount := s.fetchFreshAccount(ctx, account)
dims = buildQuotaDims(freshAccount)
account = freshAccount // use fresh data for alert metadata
}
s.checkQuotaDimCrossings(account, dims, cost, adminEmails, siteName)
}
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account {
if s.accountRepo == nil {
return snapshot
}
fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID)
if err != nil {
slog.Warn("failed to fetch fresh account for quota notify, using snapshot",
"account_id", snapshot.ID, "error", err)
return snapshot
}
return fresh
}
// checkQuotaDimCrossings iterates pre-built quota dimensions and sends alerts for threshold crossings.
// Pre-increment value is reconstructed as currentUsed - cost to detect the crossing moment.
func (s *BalanceNotifyService) checkQuotaDimCrossings(account *Account, dims []quotaDim, cost float64, adminEmails []string, siteName string) {
for _, dim := range dims {
if !dim.enabled || dim.threshold <= 0 {
continue
}
effectiveThreshold := dim.resolvedThreshold()
if effectiveThreshold <= 0 {
continue
}
newUsed := dim.currentUsed
oldUsed := dim.currentUsed - cost
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
s.asyncSendQuotaAlert(adminEmails, account.ID, account.Name, account.Platform, dim, newUsed, effectiveThreshold, siteName)
}
}
}
// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery.
func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, newUsed, effectiveThreshold float64, siteName string) {
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in quota notification", "recover", r)
}
}()
s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim, newUsed, siteName)
}()
}
// getBalanceNotifyConfig reads global balance notification settings.
func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64, rechargeURL string) {
keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold, SettingKeyBalanceLowNotifyRechargeURL}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return false, 0, ""
}
enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
threshold = f
}
}
rechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
return
}
// isAccountQuotaNotifyEnabled checks the global account quota notification toggle.
func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context) bool {
val, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEnabled)
if err != nil {
return false
}
return val == "true"
}
// getAccountQuotaNotifyEmails reads admin notification emails from settings,
// filtering out disabled and unverified entries.
func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails)
if err != nil || strings.TrimSpace(raw) == "" || raw == "[]" {
return nil
}
entries := ParseNotifyEmails(raw)
if len(entries) == 0 {
return nil
}
return filterVerifiedEmails(entries)
}
// getSiteName reads site name from settings with fallback.
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
if err != nil || name == "" {
return defaultSiteName
}
return name
}
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
var recipients []string
seen := make(map[string]bool)
for _, entry := range entries {
if entry.Disabled || !entry.Verified {
continue
}
email := strings.TrimSpace(entry.Email)
if email == "" {
continue
}
lower := strings.ToLower(email)
if seen[lower] {
continue
}
seen[lower] = true
recipients = append(recipients, email)
}
return recipients
}
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
// Only emails with verified=true and disabled=false are included.
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
}
// sendEmails sends an email to all recipients with shared timeout and error logging.
func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) {
if len(recipients) == 0 {
slog.Warn("sendEmails: no recipients", "subject", subject)
return
}
for _, to := range recipients {
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
if err := s.emailService.SendEmail(ctx, to, subject, body); err != nil {
attrs := append([]any{"to", to, "error", err}, logAttrs...)
slog.Error("failed to send notification", attrs...)
} else {
slog.Info("notification email sent successfully", "to", to, "subject", subject)
}
cancel()
}
}
// sendBalanceLowEmails sends balance low notification to all recipients.
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
displayName := userName
if displayName == "" {
displayName = userEmail
}
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
}
// sendQuotaAlertEmails sends quota alert notification to admin emails.
func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, used float64, siteName string) {
dimLabel := quotaDimLabels[dim.name]
if dimLabel == "" {
dimLabel = dim.name
}
// Format the remaining-based threshold for display
thresholdDisplay := fmt.Sprintf("$%.2f", dim.threshold)
if dim.thresholdType == thresholdTypePercentage {
thresholdDisplay = fmt.Sprintf("%.0f%%", dim.threshold)
}
remaining := dim.limit - used
if remaining < 0 {
remaining = 0
}
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)
}
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
func sanitizeEmailHeader(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
// balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold.
// The recharge button is appended dynamically when rechargeURL is set.
const balanceLowEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.balance { font-size: 36px; font-weight: bold; color: #dc2626; margin: 20px 0; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.recharge-btn { display: inline-block; margin-top: 24px; padding: 12px 32px; background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: #fff; text-decoration: none; border-radius: 6px; font-size: 16px; font-weight: bold; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333;">%s,您的余额不足</p>
<p style="color: #666;">Dear %s, your balance is running low</p>
<div class="balance">$%.2f</div>
<div class="info">
<p>您的账户余额已低于提醒阈值 <strong>$%.2f</strong>。</p>
<p>Your account balance has fallen below the alert threshold of <strong>$%.2f</strong>.</p>
<p>请及时充值以免服务中断。</p>
<p>Please top up to avoid service interruption.</p>
</div>
%s
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay.
const quotaAlertEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #ef4444 0%%, #dc2626 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; }
.metric { display: flex; justify-content: space-between; padding: 12px 0; border-bottom: 1px solid #eee; }
.metric-label { color: #666; }
.metric-value { font-weight: bold; color: #333; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; text-align: center; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333; text-align: center;">账号限额告警 / Account Quota Alert</p>
<div class="metric"><span class="metric-label">账号 ID / Account ID</span><span class="metric-value">#%d</span></div>
<div class="metric"><span class="metric-label">账号 / Account</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">平台 / Platform</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">维度 / Dimension</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">已使用 / Used</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">限额 / Limit</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">剩余额度 / Remaining</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">提醒阈值 / Alert Threshold</span><span class="metric-value">%s</span></div>
<div class="info">
<p>账号剩余额度已低于提醒阈值,请及时关注。</p>
<p>Account remaining quota has fallen below the alert threshold.</p>
</div>
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// buildBalanceLowEmailBody builds HTML email for balance low notification.
func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName, rechargeURL string) string {
rechargeBlock := ""
if rechargeURL != "" {
rechargeBlock = fmt.Sprintf(`<a href="%s" class="recharge-btn">立即充值 / Top Up Now</a>`, html.EscapeString(rechargeURL))
}
return fmt.Sprintf(balanceLowEmailTemplate, siteName, userName, userName, balance, threshold, threshold, rechargeBlock)
}
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, remaining float64, thresholdDisplay, siteName string) string {
limitStr := fmt.Sprintf("$%.2f", limit)
if limit <= 0 {
limitStr = "无限制 / Unlimited"
}
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay)
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ---------- resolveBalanceThreshold ----------
func TestResolveBalanceThreshold_Fixed(t *testing.T) {
// Fixed type always returns the raw threshold regardless of totalRecharged.
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypeFixed, 1000))
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypeFixed, 0))
require.Equal(t, 0.0, resolveBalanceThreshold(0, thresholdTypeFixed, 1000))
}
func TestResolveBalanceThreshold_Percentage(t *testing.T) {
// 10% of 1000 = 100
require.Equal(t, 100.0, resolveBalanceThreshold(10, thresholdTypePercentage, 1000))
// 50% of 200 = 100
require.Equal(t, 100.0, resolveBalanceThreshold(50, thresholdTypePercentage, 200))
}
func TestResolveBalanceThreshold_PercentageZeroRecharged(t *testing.T) {
// When totalRecharged is 0, percentage falls through to raw threshold
// (treated as fixed). This is the defensive behavior.
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypePercentage, 0))
}
func TestResolveBalanceThreshold_EmptyType(t *testing.T) {
// Empty type is treated as fixed (not percentage).
require.Equal(t, 10.0, resolveBalanceThreshold(10, "", 1000))
}
// ---------- quotaDim.resolvedThreshold ----------
func TestResolvedThreshold_FixedNormal(t *testing.T) {
// threshold=400 remaining, limit=1000 → usage trigger at 600
d := quotaDim{threshold: 400, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, 600.0, d.resolvedThreshold())
}
func TestResolvedThreshold_FixedThresholdExceedsLimit(t *testing.T) {
// threshold=1200, limit=1000 → returns negative, callers must skip
d := quotaDim{threshold: 1200, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, -200.0, d.resolvedThreshold())
}
func TestResolvedThreshold_FixedThresholdEqualsLimit(t *testing.T) {
// threshold=1000, limit=1000 → returns 0 (alert fires at 0 usage)
d := quotaDim{threshold: 1000, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, 0.0, d.resolvedThreshold())
}
func TestResolvedThreshold_PercentageNormal(t *testing.T) {
// threshold=30%, limit=1000 → usage trigger at 700 (remaining drops to 30%)
d := quotaDim{threshold: 30, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 700.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageZeroPercent(t *testing.T) {
// threshold=0%, limit=1000 → fires when remaining drops to 0 (usage=1000)
d := quotaDim{threshold: 0, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 1000.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageHundredPercent(t *testing.T) {
// threshold=100%, limit=1000 → fires immediately (remaining drops to 100% i.e. nothing used yet)
d := quotaDim{threshold: 100, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 0.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageOverHundred(t *testing.T) {
// threshold=150%, limit=1000 → returns negative (never triggers; callers skip)
d := quotaDim{threshold: 150, thresholdType: thresholdTypePercentage, limit: 1000}
require.Less(t, d.resolvedThreshold(), 0.0)
}
func TestResolvedThreshold_ZeroLimit(t *testing.T) {
// limit=0 → returns 0 to avoid division and false alerts on unlimited quotas
d := quotaDim{threshold: 100, thresholdType: thresholdTypeFixed, limit: 0}
require.Equal(t, 0.0, d.resolvedThreshold())
}
func TestResolvedThreshold_NegativeLimit(t *testing.T) {
// Negative limit treated as 0
d := quotaDim{threshold: 100, thresholdType: thresholdTypeFixed, limit: -10}
require.Equal(t, 0.0, d.resolvedThreshold())
}
// ---------- sanitizeEmailHeader ----------
func TestSanitizeEmailHeader_CRLF(t *testing.T) {
require.Equal(t, "Subject injected", sanitizeEmailHeader("Subject\r\n injected"))
}
func TestSanitizeEmailHeader_OnlyCR(t *testing.T) {
require.Equal(t, "foobar", sanitizeEmailHeader("foo\rbar"))
}
func TestSanitizeEmailHeader_OnlyLF(t *testing.T) {
require.Equal(t, "foobar", sanitizeEmailHeader("foo\nbar"))
}
func TestSanitizeEmailHeader_Clean(t *testing.T) {
require.Equal(t, "Sub2API", sanitizeEmailHeader("Sub2API"))
}
func TestSanitizeEmailHeader_Empty(t *testing.T) {
require.Equal(t, "", sanitizeEmailHeader(""))
}
func TestSanitizeEmailHeader_MultipleNewlines(t *testing.T) {
require.Equal(t, "abc", sanitizeEmailHeader("a\r\nb\r\nc"))
}
// ---------- buildQuotaDims ----------
func TestBuildQuotaDims_AllDimensionsReturned(t *testing.T) {
// Use an account with quota notify config across all 3 dimensions.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_notify_daily_threshold_type": thresholdTypeFixed,
"quota_notify_weekly_enabled": true,
"quota_notify_weekly_threshold": 20.0,
"quota_notify_weekly_threshold_type": thresholdTypePercentage,
"quota_notify_total_enabled": false,
"quota_daily_limit": 500.0,
"quota_weekly_limit": 2000.0,
"quota_limit": 10000.0,
"quota_daily_used": 50.0,
"quota_weekly_used": 300.0,
"quota_used": 1000.0,
},
}
dims := buildQuotaDims(a)
require.Len(t, dims, 3)
// Daily
require.Equal(t, quotaDimDaily, dims[0].name)
require.True(t, dims[0].enabled)
require.Equal(t, 100.0, dims[0].threshold)
require.Equal(t, thresholdTypeFixed, dims[0].thresholdType)
require.Equal(t, 500.0, dims[0].limit)
require.Equal(t, 50.0, dims[0].currentUsed)
// Weekly
require.Equal(t, quotaDimWeekly, dims[1].name)
require.True(t, dims[1].enabled)
require.Equal(t, 20.0, dims[1].threshold)
require.Equal(t, thresholdTypePercentage, dims[1].thresholdType)
require.Equal(t, 2000.0, dims[1].limit)
// Total
require.Equal(t, quotaDimTotal, dims[2].name)
require.False(t, dims[2].enabled)
require.Equal(t, 10000.0, dims[2].limit)
require.Equal(t, 1000.0, dims[2].currentUsed)
}
func TestBuildQuotaDims_EmptyExtra(t *testing.T) {
// Missing fields default to zero/disabled.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{},
}
dims := buildQuotaDims(a)
require.Len(t, dims, 3)
for _, d := range dims {
require.False(t, d.enabled)
require.Equal(t, 0.0, d.threshold)
require.Equal(t, 0.0, d.limit)
}
}
// ---------- buildQuotaDimsFromState ----------
func TestBuildQuotaDimsFromState_UsesStateValues(t *testing.T) {
// Usage values should come from the state, not the account.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_daily_used": 999.0, // should be ignored
"quota_daily_limit": 999.0, // should be ignored
},
}
state := &AccountQuotaState{
DailyUsed: 77.0,
DailyLimit: 500.0,
WeeklyUsed: 88.0,
WeeklyLimit: 2000.0,
TotalUsed: 99.0,
TotalLimit: 10000.0,
}
dims := buildQuotaDimsFromState(a, state)
require.Len(t, dims, 3)
// Settings from account (enabled, threshold, thresholdType)
require.True(t, dims[0].enabled)
require.Equal(t, 100.0, dims[0].threshold)
// Usage from state
require.Equal(t, 77.0, dims[0].currentUsed)
require.Equal(t, 500.0, dims[0].limit)
require.Equal(t, 88.0, dims[1].currentUsed)
require.Equal(t, 2000.0, dims[1].limit)
require.Equal(t, 99.0, dims[2].currentUsed)
require.Equal(t, 10000.0, dims[2].limit)
}
// ---------- collectBalanceNotifyRecipients ----------
func TestCollectBalanceNotifyRecipients_Empty(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{BalanceNotifyExtraEmails: nil}
require.Empty(t, s.collectBalanceNotifyRecipients(u))
}
func TestCollectBalanceNotifyRecipients_FiltersDisabledAndUnverified(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: "a@example.com", Verified: true, Disabled: false},
{Email: "b@example.com", Verified: true, Disabled: true}, // disabled
{Email: "c@example.com", Verified: false, Disabled: false}, // unverified
{Email: "d@example.com", Verified: true, Disabled: false},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"a@example.com", "d@example.com"}, got)
}
func TestCollectBalanceNotifyRecipients_DeduplicatesCaseInsensitive(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: "User@Example.com", Verified: true},
{Email: "user@example.com", Verified: true},
{Email: "USER@EXAMPLE.COM", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Len(t, got, 1)
// The original casing of the first entry is preserved.
require.Equal(t, "User@Example.com", got[0])
}
func TestCollectBalanceNotifyRecipients_SkipsEmpty(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: " ", Verified: true},
{Email: "", Verified: true},
{Email: "valid@example.com", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"valid@example.com"}, got)
}
func TestCollectBalanceNotifyRecipients_TrimsWhitespace(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: " trimmed@example.com ", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"trimmed@example.com"}, got)
}
......@@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}
func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService()
......@@ -719,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
}
// ---------------------------------------------------------------------------
// GetModelPricingWithChannel
// ---------------------------------------------------------------------------
func TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", nil)
require.NoError(t, err)
require.NotNil(t, pricing)
// Should be identical to GetModelPricing
original, err := svc.GetModelPricing("claude-sonnet-4")
require.NoError(t, err)
require.InDelta(t, original.InputPricePerToken, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, original.OutputPricePerToken, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, original.CacheCreationPricePerToken, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, original.CacheReadPricePerToken, pricing.CacheReadPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideInputPriceOnly(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(99e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// InputPrice overridden (both normal and priority)
require.InDelta(t, 99e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 99e-6, pricing.InputPricePerTokenPriority, 1e-12)
// OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideOutputPriceOnly(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
OutputPrice: testPtrFloat64(88e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// OutputPrice overridden
require.InDelta(t, 88e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 88e-6, pricing.OutputPricePerTokenPriority, 1e-12)
// InputPrice unchanged (claude-sonnet-4 fallback = 3e-6)
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideAllFields(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(10e-6),
OutputPrice: testPtrFloat64(20e-6),
CacheWritePrice: testPtrFloat64(5e-6),
CacheReadPrice: testPtrFloat64(1e-6),
ImageOutputPrice: testPtrFloat64(50e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
require.InDelta(t, 10e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, pricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 20e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 20e-6, pricing.OutputPricePerTokenPriority, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
require.InDelta(t, 1e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 1e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
require.InDelta(t, 50e-6, pricing.ImageOutputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
CacheWritePrice: testPtrFloat64(7e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h
require.InDelta(t, 7e-6, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, 7e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 7e-6, pricing.CacheCreation1hPrice, 1e-12)
}
func TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
CacheReadPrice: testPtrFloat64(2e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// CacheReadPrice should set both normal and priority
require.InDelta(t, 2e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 2e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
}
func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(1e-6),
}
pricing, err := svc.GetModelPricingWithChannel("totally-unknown-model", chPricing)
require.Error(t, err)
require.Nil(t, pricing)
require.Contains(t, err.Error(), "pricing not found")
}
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// CalculateCostUnified
// ---------------------------------------------------------------------------
func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
input := CostInput{
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: nil, // no resolver
}
cost, err := svc.CalculateCostUnified(input)
require.NoError(t, err)
// Should match the old-path result exactly
expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil)
require.NoError(t, err)
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10)
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
require.Empty(t, cost.BillingMode)
}
func TestCalculateCostUnified_TokenMode(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
input := CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.5,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
expectedTotal := 1000*3e-6 + 500*15e-6
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModeToken), cost.BillingMode)
}
func TestCalculateCostUnified_PerRequestMode(t *testing.T) {
// Set up a ChannelService with a per-request pricing channel
cs := newTestChannelServiceWithCache(t, &channelCache{
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
{groupID: 1, model: "claude-sonnet-4"}: {
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
},
},
channelByGroupID: map[int64]*Channel{
1: {ID: 1, Status: StatusActive},
},
groupPlatform: map[int64]string{1: ""},
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
mappingByGroupModel: map[channelModelKey]string{},
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
byID: map[int64]*Channel{},
})
bs := newTestBillingService()
resolver := NewModelPricingResolver(cs, bs)
groupID := int64(1)
input := CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
GroupID: &groupID,
Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50},
RequestCount: 3,
RateMultiplier: 2.0,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// 3 requests * $0.05 = $0.15
require.InDelta(t, 0.15, cost.TotalCost, 1e-10)
// ActualCost = 0.15 * 2.0 = 0.30
require.InDelta(t, 0.30, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
}
func TestCalculateCostUnified_ImageMode(t *testing.T) {
cs := newTestChannelServiceWithCache(t, &channelCache{
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
{groupID: 2, model: "gemini-image"}: {
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.10),
},
},
channelByGroupID: map[int64]*Channel{
2: {ID: 2, Status: StatusActive},
},
groupPlatform: map[int64]string{2: ""},
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
mappingByGroupModel: map[channelModelKey]string{},
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
byID: map[int64]*Channel{},
})
bs := &BillingService{
cfg: &config.Config{},
fallbackPrices: map[string]*ModelPricing{},
}
resolver := NewModelPricingResolver(cs, bs)
groupID := int64(2)
input := CostInput{
Ctx: context.Background(),
Model: "gemini-image",
GroupID: &groupID,
Tokens: UsageTokens{},
RequestCount: 2,
RateMultiplier: 1.0,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// 2 * $0.10 = $0.20
require.InDelta(t, 0.20, cost.TotalCost, 1e-10)
require.InDelta(t, 0.20, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModeImage), cost.BillingMode)
}
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
costZero, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 0, // should default to 1.0
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
}
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000}
costNeg, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: -5.0,
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
}
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: UsageTokens{InputTokens: 100},
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.Equal(t, "token", cost.BillingMode)
}
func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
preResolved := &ResolvedPricing{
Mode: BillingModePerRequest,
DefaultPerRequestPrice: 0.07,
}
cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: UsageTokens{InputTokens: 100},
RequestCount: 2,
RateMultiplier: 1.0,
Resolver: resolver,
Resolved: preResolved,
})
require.NoError(t, err)
require.NotNil(t, cost)
// 2 * $0.07 = $0.14
require.InDelta(t, 0.14, cost.TotalCost, 1e-10)
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
}
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
// cache snapshot, bypassing the repository layer entirely.
func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService {
t.Helper()
cs := &ChannelService{}
cache.loadedAt = time.Now()
cs.cache.Store(cache)
return cs
}
......@@ -37,8 +37,10 @@ type Channel struct {
Name string
Description string
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
Features string // 渠道特性描述(JSON 数组),用于支付页面展示
FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation)
CreatedAt time.Time
UpdatedAt time.Time
......@@ -48,6 +50,25 @@ type Channel struct {
ModelPricing []ChannelModelPricing
// 渠道级模型映射(按平台分组:platform → {src→dst})
ModelMapping map[string]map[string]string
// 账号统计定价
ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
}
// AccountStatsPricingRule 账号统计定价规则
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
// 多条规则按 SortOrder 排序,先命中为准。
type AccountStatsPricingRule struct {
ID int64
ChannelID int64
Name string
GroupIDs []int64
AccountIDs []int64
SortOrder int
Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构)
CreatedAt time.Time
UpdatedAt time.Time
}
// ChannelModelPricing 渠道模型定价条目
......@@ -176,9 +197,58 @@ func (c *Channel) Clone() *Channel {
cp.ModelMapping[platform] = inner
}
}
if c.FeaturesConfig != nil {
cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig)
}
if c.AccountStatsPricingRules != nil {
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
for i, rule := range c.AccountStatsPricingRules {
cp.AccountStatsPricingRules[i] = rule
if rule.GroupIDs != nil {
cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
}
if rule.AccountIDs != nil {
cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
}
if rule.Pricing != nil {
cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
for j := range rule.Pricing {
cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
}
}
}
}
return &cp
}
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
if c == nil || c.FeaturesConfig == nil {
return false
}
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
if !ok {
return false
}
enabled, ok := wse[platform].(bool)
return ok && enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src))
for k, v := range src {
if inner, ok := v.(map[string]any); ok {
dst[k] = deepCopyFeaturesConfig(inner)
} else {
dst[k] = v
}
}
return dst
}
// ValidateIntervals 校验区间列表的合法性。
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
......
......@@ -81,9 +81,9 @@ type wildcardMappingEntry struct {
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(按配置顺序,先匹配先使用
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(按配置顺序,先匹配先使用
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
......@@ -315,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
expandMappingToCache(cache, ch, gid, platform)
}
}
return cache
}
......@@ -415,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
return ch.Clone(), nil
}
// GetGroupPlatform 获取分组的平台标识(从缓存)
func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
cache, err := s.loadCache(ctx)
if err != nil {
return ""
}
return cache.groupPlatform[groupID]
}
// channelLookup 热路径公共查找结果
type channelLookup struct {
cache *channelCache
......@@ -556,13 +566,19 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
if err := validateNoConflictingModels(pricing); err != nil {
if err := validatePricingEntries(pricing); err != nil {
return err
}
if err := validatePricingIntervals(pricing); err != nil {
return validateNoConflictingMappings(mapping)
}
// validatePricingEntries 校验定价条目(冲突检测 + 区间校验 + 计费模式校验),
// 同时用于主渠道定价和 account_stats_pricing_rules 的内部定价。
func validatePricingEntries(pricing []ChannelModelPricing) error {
if err := validateNoConflictingModels(pricing); err != nil {
return err
}
if err := validateNoConflictingMappings(mapping); err != nil {
if err := validatePricingIntervals(pricing); err != nil {
return err
}
return validatePricingBillingMode(pricing)
......@@ -655,14 +671,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
}
channel := &Channel{
Name: input.Name,
Description: input.Description,
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Name: input.Name,
Description: input.Description,
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Features: input.Features,
FeaturesConfig: input.FeaturesConfig,
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
AccountStatsPricingRules: input.AccountStatsPricingRules,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
......@@ -671,6 +691,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
for i, rule := range channel.AccountStatsPricingRules {
if err := validatePricingEntries(rule.Pricing); err != nil {
return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err)
}
}
if err := s.repo.Create(ctx, channel); err != nil {
return nil, fmt.Errorf("create channel: %w", err)
......@@ -699,6 +724,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
for i, rule := range channel.AccountStatsPricingRules {
if err := validatePricingEntries(rule.Pricing); err != nil {
return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err)
}
}
oldGroupIDs := s.getOldGroupIDs(ctx, id)
......@@ -733,6 +763,9 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
if input.Features != nil {
channel.Features = *input.Features
}
if input.GroupIDs != nil {
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
return err
......@@ -748,6 +781,15 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if input.FeaturesConfig != nil {
channel.FeaturesConfig = input.FeaturesConfig
}
if input.ApplyPricingToAccountStats != nil {
channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
}
if input.AccountStatsPricingRules != nil {
channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
}
return nil
}
......@@ -913,23 +955,31 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
// CreateChannelInput 创建渠道输入
type CreateChannelInput struct {
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
Features string
FeaturesConfig map[string]any
ApplyPricingToAccountStats bool
AccountStatsPricingRules []AccountStatsPricingRule
}
// UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct {
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
Features *string
FeaturesConfig map[string]any
ApplyPricingToAccountStats *bool
AccountStatsPricingRules *[]AccountStatsPricingRule
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
},
}
require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("openai"))
}
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
c := &Channel{FeaturesConfig: nil}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
var c *Channel
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: true, // not a map
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
......@@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
......@@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
}
......@@ -249,6 +249,18 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD)
SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL
// Account Quota Notification
SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组)
// Web Search Emulation
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
......
......@@ -7,8 +7,9 @@ import (
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"log/slog"
"math/big"
"net"
"net/smtp"
"net/url"
"strconv"
......@@ -34,6 +35,11 @@ type EmailCache interface {
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
// Notify email verification code methods
GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteNotifyVerifyCode(ctx context.Context, email string) error
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
......@@ -43,6 +49,10 @@ type EmailCache interface {
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
// Notify code rate limiting per user
IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error)
GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error)
}
// VerificationCodeData represents verification code data
......@@ -50,6 +60,7 @@ type VerificationCodeData struct {
Code string
Attempts int
CreatedAt time.Time
ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts
}
// PasswordResetTokenData represents password reset token data
......@@ -146,11 +157,18 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
return s.SendEmailWithConfig(config, to, subject, body)
}
const smtpDialTimeout = 10 * time.Second
const smtpIOTimeout = 20 * time.Second
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From
// Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
to = sanitizeEmailHeader(to)
subject = sanitizeEmailHeader(subject)
from := sanitizeEmailHeader(config.From)
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
from = fmt.Sprintf("%s <%s>", sanitizeEmailHeader(config.FromName), sanitizeEmailHeader(config.From))
}
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s",
......@@ -163,7 +181,54 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
}
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host)
}
// sendMailPlain sends mail without TLS using a dialer with timeout.
func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return fmt.Errorf("smtp dial: %w", err)
}
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host)
if err != nil {
return fmt.Errorf("new smtp client: %w", err)
}
defer func() { _ = client.Close() }()
// Opportunistic STARTTLS: upgrade to encrypted connection if the server supports it.
// This mirrors the behavior of smtp.SendMail which we replaced for timeout support.
if ok, _ := client.Extension("STARTTLS"); ok {
if err = client.StartTLS(&tls.Config{ServerName: host, MinVersion: tls.VersionTLS12}); err != nil {
return fmt.Errorf("starttls: %w", err)
}
}
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
}
if err = client.Mail(from); err != nil {
return fmt.Errorf("smtp mail: %w", err)
}
if err = client.Rcpt(to); err != nil {
return fmt.Errorf("smtp rcpt: %w", err)
}
w, err := client.Data()
if err != nil {
return fmt.Errorf("smtp data: %w", err)
}
if _, err = w.Write(msg); err != nil {
return fmt.Errorf("write msg: %w", err)
}
if err = w.Close(); err != nil {
return fmt.Errorf("close writer: %w", err)
}
_ = client.Quit()
return nil
}
// sendMailTLS 使用TLS发送邮件
......@@ -174,10 +239,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
MinVersion: tls.VersionTLS12,
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls dial: %w", err)
}
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host)
......@@ -254,6 +321,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
}
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
......@@ -286,8 +354,12 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update verification attempt count", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
......@@ -297,7 +369,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证成功,删除验证码
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err)
slog.Error("failed to delete verification code after success", "email", email, "error", err)
}
return nil
}
......@@ -447,7 +519,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
slog.Info("password reset email skipped due to cooldown", "email", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
......@@ -458,7 +530,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
slog.Error("failed to set password reset cooldown", "email", email, "error", err)
}
return nil
......@@ -488,7 +560,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
slog.Error("failed to delete password reset token after consumption", "email", email, "error", err)
}
return nil
}
......
......@@ -43,6 +43,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
)
}
......
......@@ -75,6 +75,9 @@ type ParsedRequest struct {
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
// GroupID 请求所属分组 ID(来自 API Key)
GroupID *int64
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
OnUpstreamAccepted func()
......
......@@ -503,7 +503,6 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
......@@ -570,6 +569,7 @@ type GatewayService struct {
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
balanceNotifyService *BalanceNotifyService
}
// NewGatewayService creates a new GatewayService
......@@ -599,6 +599,7 @@ func NewGatewayService(
tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
balanceNotifyService *BalanceNotifyService,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
......@@ -633,6 +634,7 @@ func NewGatewayService(
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
balanceNotifyService: balanceNotifyService,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
......@@ -1329,6 +1331,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
......@@ -1337,12 +1344,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return excluded
}
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
......@@ -1598,7 +1599,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
account, ok := accountByID[accountID]
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
......@@ -1614,7 +1614,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
......@@ -1628,10 +1627,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: accountID,
......@@ -2740,7 +2737,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
......@@ -3119,7 +3116,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
......@@ -3435,6 +3432,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
}
// OpenAI 透传模式:仅替换认证,允许所有模型
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
return true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
......@@ -3934,6 +3935,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request")
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body
passthroughModel := parsed.Model
......@@ -7279,6 +7285,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ParsedRequest *ParsedRequest
APIKey *APIKey
User *User
Account *Account
......@@ -7333,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
// - API Key 限速用量更新
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
// postUsageBilling is the legacy fallback billing path used when the unified
// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
// for atomic billing. This path only runs in tests or degraded mode.
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
billingCtx, cancel := detachedBillingContext(ctx)
defer cancel()
cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill {
if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
}
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
}
}
// 2. API Key 配额
if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
......@@ -7383,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
}
finalizePostUsageBilling(p, deps)
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
// cache updates. The legacy path does DB writes directly; the finalize path
// does cache queue + notifications. Notifications are dispatched separately
// by the caller after recording the usage log.
}
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
......@@ -7499,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
}
}
finalizePostUsageBilling(p, deps)
finalizePostUsageBilling(p, deps, result)
return true, nil
}
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
if p == nil || p.Cost == nil || deps == nil {
return
}
......@@ -7521,6 +7523,83 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
}
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
// Notification checks run async — all parameters are already captured,
// no dependency on the request context or upstream connection.
go notifyBalanceLow(p, deps, result)
go notifyAccountQuota(p, deps, result)
}
// notifyBalanceLow sends balance low notification after deduction.
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyBalanceLow", "recover", r)
}
}()
if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil {
slog.Debug("notifyBalanceLow: skipped",
"is_subscription", p.IsSubscriptionBill,
"actual_cost", p.Cost.ActualCost,
"user_nil", p.User == nil,
"service_nil", deps.balanceNotifyService == nil,
)
return
}
oldBalance := resolveOldBalance(p, result)
slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction",
"user_id", p.User.ID,
"old_balance", oldBalance,
"cost", p.Cost.ActualCost,
"notify_enabled", p.User.BalanceNotifyEnabled,
"threshold", p.User.BalanceNotifyThreshold,
"result_has_new_balance", result != nil && result.NewBalance != nil,
)
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
}
// resolveOldBalance returns the pre-deduction balance.
// Prefers the DB transaction result (newBalance + cost) over snapshot.
func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 {
if result != nil && result.NewBalance != nil {
return *result.NewBalance + p.Cost.ActualCost
}
// Legacy fallback: snapshot balance from request context
return p.User.Balance
}
// notifyAccountQuota sends account quota threshold notification after increment.
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
// to avoid a separate DB read that may see stale or concurrently-modified data.
func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyAccountQuota", "recover", r)
}
}()
if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil {
slog.Debug("notifyAccountQuota: skipped",
"total_cost", p.Cost.TotalCost,
"account_nil", p.Account == nil,
"is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(),
"service_nil", deps.balanceNotifyService == nil,
)
return
}
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
var quotaState *AccountQuotaState
if result != nil {
quotaState = result.QuotaState
}
slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement",
"account_id", p.Account.ID,
"account_cost", accountCost,
"has_quota_state", quotaState != nil,
)
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState)
}
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
......@@ -7543,20 +7622,22 @@ func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Cont
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct {
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
balanceNotifyService *BalanceNotifyService
}
func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
}
}
......@@ -7746,6 +7827,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
},
cost.TotalCost,
)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
......@@ -8086,6 +8184,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return ch.BillingModelSource == BillingModelSourceUpstream
}
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
if groupID == nil {
return false
}
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
return false
}
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
......
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
// Web search emulation constants
const (
toolTypeWebSearchPrefix = "web_search"
toolTypeGoogleSearch = "google_search"
toolNameWebSearch = "web_search"
toolNameGoogleSearch = "google_search"
toolNameWebSearch2025 = "web_search_20250305"
webSearchDefaultMaxResults = 5
defaultWebSearchModel = "claude-sonnet-4-6"
webSearchMsgIDPrefix = "msg_ws_"
webSearchToolUseIDPrefix = "srvtoolu_ws_"
tokenEstimateDivisor = 4
// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
featureKeyWebSearchEmulation = "web_search_emulation"
)
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
var webSearchManagerPtr atomic.Pointer[websearch.Manager]
// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
func SetWebSearchManager(m *websearch.Manager) {
webSearchManagerPtr.Store(m)
}
func getWebSearchManager() *websearch.Manager {
return webSearchManagerPtr.Load()
}
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
if getWebSearchManager() == nil {
return false
}
if !isOnlyWebSearchToolInBody(body) {
return false
}
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
return false
}
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
return false
default: // "default" → follow channel config
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(account.Platform)
}
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
func isOnlyWebSearchToolInBody(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return false
}
arr := tools.Array()
if len(arr) != 1 {
return false
}
return isWebSearchToolJSON(arr[0])
}
func isWebSearchToolJSON(tool gjson.Result) bool {
toolType := tool.Get("type").String()
if strings.HasPrefix(toolType, toolTypeWebSearchPrefix) || toolType == toolTypeGoogleSearch {
return true
}
switch tool.Get("name").String() {
case toolNameWebSearch, toolNameGoogleSearch, toolNameWebSearch2025:
return true
}
return false
}
// extractSearchQueryFromBody extracts the last user message text as the search query.
func extractSearchQueryFromBody(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return ""
}
arr := messages.Array()
if len(arr) == 0 {
return ""
}
lastMsg := arr[len(arr)-1]
if lastMsg.Get("role").String() != "user" {
return ""
}
return extractWebSearchTextFromContent(lastMsg.Get("content"))
}
func extractWebSearchTextFromContent(content gjson.Result) string {
if content.Type == gjson.String {
return content.String()
}
if content.IsArray() {
for _, block := range content.Array() {
if block.Get("type").String() == "text" {
if text := block.Get("text").String(); text != "" {
return text
}
}
}
}
return ""
}
// handleWebSearchEmulation intercepts a web-search-only request,
// calls a third-party search API, and constructs an Anthropic-format response.
func (s *GatewayService) handleWebSearchEmulation(
ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
) (*ForwardResult, error) {
startTime := time.Now()
// Release the serial queue lock immediately — we don't need upstream.
if parsed.OnUpstreamAccepted != nil {
parsed.OnUpstreamAccepted()
}
query := extractSearchQueryFromBody(parsed.Body)
if query == "" {
return nil, fmt.Errorf("web search emulation: no query found in messages")
}
slog.Info("web search emulation: executing search",
"account_id", account.ID, "account_name", account.Name, "query", query)
resp, providerName, err := doWebSearch(ctx, account, query)
if err != nil {
// Proxy unavailable → trigger account switch via UpstreamFailoverError
if errors.Is(err, websearch.ErrProxyUnavailable) {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(err.Error()),
}
}
return nil, err
}
slog.Info("web search emulation: search completed",
"provider", providerName, "results_count", len(resp.Results))
model := parsed.Model
if model == "" {
model = defaultWebSearchModel
}
if parsed.Stream {
return writeWebSearchStreamResponse(c, query, resp, model, startTime)
}
return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
}
func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
proxyURL := resolveAccountProxyURL(account)
mgr := getWebSearchManager()
if mgr == nil {
return nil, "", fmt.Errorf("web search emulation: manager not initialized")
}
resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
})
if err != nil {
slog.Error("web search emulation: search failed", "error", err)
return nil, "", fmt.Errorf("web search emulation: %w", err)
}
return resp, providerName, nil
}
func resolveAccountProxyURL(account *Account) string {
if account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}
// --- SSE streaming response ---
func writeWebSearchStreamResponse(
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
msgID := webSearchMsgIDPrefix + uuid.New().String()
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
textSummary := buildTextSummary(query, resp.Results)
setSSEHeaders(c)
w := c.Writer
for _, fn := range []func() error{
func() error { return writeSSEMessageStart(w, msgID, model) },
func() error { return writeSSEServerToolUse(w, toolUseID, query, 0) },
func() error { return writeSSEToolResult(w, toolUseID, resp.Results, 1) },
func() error { return writeSSETextBlock(w, textSummary, 2) },
func() error { return writeSSEMessageEnd(w, len(textSummary)/tokenEstimateDivisor) },
} {
if err := fn(); err != nil {
slog.Warn("web search emulation: SSE write failed, stopping", "error", err)
break
}
}
w.Flush()
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
func setSSEHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
}
func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
evt := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": msgID, "type": "message", "role": "assistant", "model": model,
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
},
}
return flushSSEJSON(w, "message_start", evt)
}
func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) error {
start := map[string]any{
"type": "content_block_start", "index": index,
"content_block": map[string]any{
"type": "server_tool_use", "id": toolUseID,
"name": toolNameWebSearch, "input": map[string]string{"query": query},
},
}
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) error {
start := map[string]any{
"type": "content_block_start", "index": index,
"content_block": map[string]any{
"type": "web_search_tool_result", "tool_use_id": toolUseID,
"content": buildSearchResultBlocks(results),
},
}
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
func writeSSETextBlock(w http.ResponseWriter, text string, index int) error {
if err := flushSSEJSON(w, "content_block_start", map[string]any{
"type": "content_block_start", "index": index,
"content_block": map[string]any{"type": "text", "text": ""},
}); err != nil {
return err
}
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
"type": "content_block_delta", "index": index,
"delta": map[string]string{"type": "text_delta", "text": text},
}); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) error {
if err := flushSSEJSON(w, "message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
"usage": map[string]int{"output_tokens": outputTokens},
}); err != nil {
return err
}
return flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
}
// flushSSEJSON marshals data to JSON and writes an SSE event.
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b); err != nil {
return fmt.Errorf("write: %w", err)
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
}
// --- Non-streaming JSON response ---
func writeWebSearchNonStreamResponse(
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
msgID := webSearchMsgIDPrefix + uuid.New().String()
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
textSummary := buildTextSummary(query, resp.Results)
msg := map[string]any{
"id": msgID, "type": "message", "role": "assistant", "model": model,
"content": []any{
map[string]any{
"type": "server_tool_use", "id": toolUseID,
"name": toolNameWebSearch, "input": map[string]string{"query": query},
},
map[string]any{
"type": "web_search_tool_result", "tool_use_id": toolUseID,
"content": buildSearchResultBlocks(resp.Results),
},
map[string]any{"type": "text", "text": textSummary},
},
"stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]int{"input_tokens": 0, "output_tokens": len(textSummary) / tokenEstimateDivisor},
}
body, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
}
c.Data(http.StatusOK, "application/json", body)
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
// --- Helpers ---
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
blocks := make([]map[string]string, 0, len(results))
for _, r := range results {
block := map[string]string{
"type": "web_search_result",
"url": r.URL,
"title": r.Title,
}
if r.Snippet != "" {
block["page_content"] = r.Snippet
}
if r.PageAge != "" {
block["page_age"] = r.PageAge
}
blocks = append(blocks, block)
}
return blocks
}
func buildTextSummary(query string, results []websearch.SearchResult) string {
if len(results) == 0 {
return "No search results found for: " + query
}
var sb strings.Builder
fmt.Fprintf(&sb, "Here are the search results for \"%s\":\n\n", query)
for i, r := range results {
fmt.Fprintf(&sb, "%d. **%s**\n %s\n %s\n\n", i+1, r.Title, r.URL, r.Snippet)
}
return sb.String()
}
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/stretchr/testify/require"
)
// --- isOnlyWebSearchToolInBody ---
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody(
[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
}
func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
}
func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
}
// --- extractSearchQueryFromBody ---
func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
body := `{"messages":[{"role":"user","content":"what is golang"}]}`
require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
}
func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
}
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
// --- buildSearchResultBlocks ---
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
results := []websearch.SearchResult{
{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
}
blocks := buildSearchResultBlocks(results)
require.Len(t, blocks, 2)
require.Equal(t, "web_search_result", blocks[0]["type"])
require.Equal(t, "https://a.com", blocks[0]["url"])
require.Equal(t, "snippet a", blocks[0]["page_content"])
require.Equal(t, "2 days", blocks[0]["page_age"])
// Second result has no PageAge
require.Equal(t, "https://b.com", blocks[1]["url"])
_, hasPageAge := blocks[1]["page_age"]
require.False(t, hasPageAge)
}
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
blocks := buildSearchResultBlocks(nil)
require.Empty(t, blocks)
}
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
_, hasContent := blocks[0]["page_content"]
require.False(t, hasContent)
}
// --- buildTextSummary ---
func TestBuildTextSummary_WithResults(t *testing.T) {
results := []websearch.SearchResult{
{URL: "https://a.com", Title: "A", Snippet: "desc a"},
}
summary := buildTextSummary("test query", results)
require.Contains(t, summary, "test query")
require.Contains(t, summary, "1. **A**")
require.Contains(t, summary, "https://a.com")
}
func TestBuildTextSummary_NoResults(t *testing.T) {
summary := buildTextSummary("test", nil)
require.Contains(t, summary, "No search results found for: test")
}
// --- shouldEmulateWebSearch ---
// webSearchToolBody is a valid request body with exactly one web_search tool.
var webSearchToolBody = []byte(`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`)
// nonWebSearchToolBody is a request body without web_search tool.
var nonWebSearchToolBody = []byte(`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`)
// newAnthropicAPIKeyAccount creates a test Account with the given web search emulation mode.
func newAnthropicAPIKeyAccount(mode string) *Account {
return &Account{
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: mode},
}
}
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
config: cfg,
expiresAt: time.Now().Add(10 * time.Minute).UnixNano(),
})
}
// clearGlobalWebSearchConfig resets the global cache to force re-read.
func clearGlobalWebSearchConfig() {
webSearchEmulationCache.Store((*cachedWebSearchEmulationConfig)(nil))
}
// newSettingServiceForWebSearchTest creates a SettingService with a mock repo pre-loaded with config.
func newSettingServiceForWebSearchTest(enabled bool) *SettingService {
repo := newMockSettingRepo()
cfg := &WebSearchEmulationConfig{
Enabled: enabled,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "sk-test"}},
}
data, _ := json.Marshal(cfg)
repo.data[SettingKeyWebSearchEmulationConfig] = string(data)
return NewSettingService(repo, &config.Config{})
}
// newChannelServiceWithCache creates a ChannelService with a pre-built cache containing the channel.
func newChannelServiceWithCache(groupID int64, ch *Channel) *ChannelService {
svc := &ChannelService{}
cache := &channelCache{
channelByGroupID: map[int64]*Channel{groupID: ch},
byID: map[int64]*Channel{ch.ID: ch},
groupPlatform: map[int64]string{},
loadedAt: time.Now(),
}
svc.cache.Store(cache)
return svc
}
func TestShouldEmulateWebSearch_NilManager(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
settingSvc := newSettingServiceForWebSearchTest(true)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_NotOnlyWebSearchTool(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
settingSvc := newSettingServiceForWebSearchTest(true)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, nonWebSearchToolBody))
}
func TestShouldEmulateWebSearch_GlobalDisabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
// Global config disabled
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: false,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(false)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_AccountDisabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDisabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_AccountEnabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 10,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: true},
},
}
channelSvc := newChannelServiceWithCache(42, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
groupID := int64(42)
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 10,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: false},
},
}
channelSvc := newChannelServiceWithCache(42, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
groupID := int64(42)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_NilGroupID(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
// nil groupID + default mode → falls through to channel check → returns false
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc, channelService: nil}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
groupID := int64(42)
// nil channelService + default mode → returns false
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
package service
import (
"encoding/json"
"strings"
)
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
Email string `json:"email"`
Disabled bool `json:"disabled"`
Verified bool `json:"verified"`
}
// parseNotifyEmails parses a JSON string into []NotifyEmailEntry.
// It auto-detects the format:
// - Old format ["email1","email2"] → converted to [{email, disabled:false, verified:true}, ...]
// - New format [{email,disabled,verified}, ...] → parsed directly
//
// Returns nil on empty/invalid input.
func ParseNotifyEmails(raw string) []NotifyEmailEntry {
raw = strings.TrimSpace(raw)
if raw == "" || raw == "[]" {
return nil
}
// Try parsing as new format first (array of objects)
var entries []NotifyEmailEntry
if err := json.Unmarshal([]byte(raw), &entries); err == nil && len(entries) > 0 {
// Verify it's actually the new format by checking the first element
// json.Unmarshal into []NotifyEmailEntry succeeds even for ["string"]
// because it tries to fit "string" into NotifyEmailEntry and gets zero values.
// We need to detect old format explicitly.
if !isOldStringArrayFormat(raw) {
return entries
}
}
// Try parsing as old format (array of strings)
var emails []string
if err := json.Unmarshal([]byte(raw), &emails); err == nil {
result := make([]NotifyEmailEntry, 0, len(emails))
for _, e := range emails {
e = strings.TrimSpace(e)
if e != "" {
result = append(result, NotifyEmailEntry{
Email: e,
Disabled: false,
Verified: false, // Old format emails default to unverified
})
}
}
return result
}
return nil
}
// isOldStringArrayFormat checks if the JSON is a string array like ["email1","email2"].
func isOldStringArrayFormat(raw string) bool {
var arr []json.RawMessage
if err := json.Unmarshal([]byte(raw), &arr); err != nil || len(arr) == 0 {
return false
}
// Check if first element starts with a quote (string) vs { (object)
first := strings.TrimSpace(string(arr[0]))
return len(first) > 0 && first[0] == '"'
}
// marshalNotifyEmails serializes []NotifyEmailEntry to JSON string.
func MarshalNotifyEmails(entries []NotifyEmailEntry) string {
if len(entries) == 0 {
return "[]"
}
data, err := json.Marshal(entries)
if err != nil {
return "[]"
}
return string(data)
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ---------- ParseNotifyEmails ----------
func TestParseNotifyEmails_EmptyString(t *testing.T) {
result := ParseNotifyEmails("")
require.Nil(t, result)
}
func TestParseNotifyEmails_EmptyArray(t *testing.T) {
result := ParseNotifyEmails("[]")
require.Nil(t, result)
}
func TestParseNotifyEmails_Null(t *testing.T) {
// "null" is valid JSON that unmarshals into a nil string slice.
// The old-format branch then returns an empty (non-nil) slice.
result := ParseNotifyEmails("null")
require.Empty(t, result)
}
func TestParseNotifyEmails_WhitespaceOnly(t *testing.T) {
result := ParseNotifyEmails(" ")
require.Nil(t, result)
}
func TestParseNotifyEmails_OldFormat(t *testing.T) {
raw := `["alice@example.com", "bob@example.com"]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 2)
require.Equal(t, "alice@example.com", result[0].Email)
require.False(t, result[0].Verified, "old format emails should default to unverified")
require.False(t, result[0].Disabled)
require.Equal(t, "bob@example.com", result[1].Email)
require.False(t, result[1].Verified)
require.False(t, result[1].Disabled)
}
func TestParseNotifyEmails_OldFormat_SkipsEmptyEntries(t *testing.T) {
raw := `["alice@example.com", "", " ", "bob@example.com"]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 2)
require.Equal(t, "alice@example.com", result[0].Email)
require.Equal(t, "bob@example.com", result[1].Email)
}
func TestParseNotifyEmails_NewFormat(t *testing.T) {
raw := `[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 2)
require.Equal(t, "alice@example.com", result[0].Email)
require.True(t, result[0].Verified)
require.False(t, result[0].Disabled)
require.Equal(t, "bob@example.com", result[1].Email)
require.False(t, result[1].Verified)
require.True(t, result[1].Disabled)
}
func TestParseNotifyEmails_NewFormat_SingleEntry(t *testing.T) {
raw := `[{"email":"solo@example.com","verified":true,"disabled":false}]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 1)
require.Equal(t, "solo@example.com", result[0].Email)
require.True(t, result[0].Verified)
}
func TestParseNotifyEmails_InvalidJSON(t *testing.T) {
result := ParseNotifyEmails(`{not valid json`)
require.Nil(t, result)
}
func TestParseNotifyEmails_InvalidJSONObject(t *testing.T) {
// A plain JSON object (not array) should return nil.
result := ParseNotifyEmails(`{"email":"a@b.com"}`)
require.Nil(t, result)
}
func TestParseNotifyEmails_WhitespacePadding(t *testing.T) {
raw := ` ["padded@example.com"] `
result := ParseNotifyEmails(raw)
require.Len(t, result, 1)
require.Equal(t, "padded@example.com", result[0].Email)
}
// ---------- MarshalNotifyEmails ----------
func TestMarshalNotifyEmails_EmptySlice(t *testing.T) {
result := MarshalNotifyEmails([]NotifyEmailEntry{})
require.Equal(t, "[]", result)
}
func TestMarshalNotifyEmails_NilSlice(t *testing.T) {
result := MarshalNotifyEmails(nil)
require.Equal(t, "[]", result)
}
func TestMarshalNotifyEmails_SingleEntry(t *testing.T) {
entries := []NotifyEmailEntry{
{Email: "test@example.com", Verified: true, Disabled: false},
}
result := MarshalNotifyEmails(entries)
require.Contains(t, result, `"email":"test@example.com"`)
require.Contains(t, result, `"verified":true`)
require.Contains(t, result, `"disabled":false`)
// Round-trip: parsing the marshalled result should produce the original entries.
parsed := ParseNotifyEmails(result)
require.Len(t, parsed, 1)
require.Equal(t, entries[0], parsed[0])
}
func TestMarshalNotifyEmails_MultipleEntries(t *testing.T) {
entries := []NotifyEmailEntry{
{Email: "a@example.com", Verified: true, Disabled: false},
{Email: "b@example.com", Verified: false, Disabled: true},
}
result := MarshalNotifyEmails(entries)
// Round-trip verification.
parsed := ParseNotifyEmails(result)
require.Len(t, parsed, 2)
require.Equal(t, entries[0], parsed[0])
require.Equal(t, entries[1], parsed[1])
}
func TestMarshalNotifyEmails_RoundTrip_NewFormat(t *testing.T) {
original := []NotifyEmailEntry{
{Email: "x@example.com", Verified: true, Disabled: true},
{Email: "y@example.com", Verified: false, Disabled: false},
}
marshalled := MarshalNotifyEmails(original)
parsed := ParseNotifyEmails(marshalled)
require.Equal(t, original, parsed)
}
// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ----------
func TestParseNotifyEmails_MixedOldFormatWithWhitespace(t *testing.T) {
// Emails with leading/trailing whitespace in old format should be trimmed.
raw := `[" alice@example.com "]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 1)
require.Equal(t, "alice@example.com", result[0].Email)
}
......@@ -147,6 +147,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
......
......@@ -327,6 +327,7 @@ type OpenAIGatewayService struct {
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
......@@ -364,6 +365,7 @@ func NewOpenAIGatewayService(
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
......@@ -393,6 +395,7 @@ func NewOpenAIGatewayService(
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
balanceNotifyService: balanceNotifyService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
......@@ -477,11 +480,12 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
}
}
......@@ -4569,6 +4573,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
tokens, cost.TotalCost,
)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
......
......@@ -413,7 +413,12 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
// After normal client close, the server goroutine may receive the close frame
// as an error — this is expected behavior, not a test failure.
if serverErr != nil {
require.Contains(t, serverErr.Error(), "StatusNormalClosure",
"server error should only be a normal close frame, got: %v", serverErr)
}
case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时")
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment