"backend/vscode:/vscode.git/clone" did not exist on "a68df457d80dd67b173df2a74fd69de1206c0f2e"
Unverified Commit d402e722 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1637 from touwaeriol/feat/websearch-notify-pricing

feat: web search emulation, balance/quota notify, account stats pricing, per-provider refund control, Stripe fix / Web 搜索模拟、余额配额通知、渠道统计计费、按服务商退款控制、Stripe 修复
parents e534e9ba 8548a130
...@@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { ...@@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
configPath := filepath.Join(tempDir, "config.yaml") configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644)) require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644)) yamlSafePath := filepath.ToSlash(templatePath)
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir) t.Setenv("DATA_DIR", tempDir)
cfg, err := Load() cfg, err := Load()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile) require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate) require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
} }
......
...@@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { ...@@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{ c.JSON(409, gin.H{
"error": "mixed_channel_warning", "error": "mixed_channel_warning",
"message": mixedErr.Error(), "message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
}) })
return return
} }
......
package admin package admin
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
...@@ -33,6 +34,10 @@ type createChannelRequest struct { ...@@ -33,6 +34,10 @@ type createChannelRequest struct {
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
} }
type updateChannelRequest struct { type updateChannelRequest struct {
...@@ -44,6 +49,10 @@ type updateChannelRequest struct { ...@@ -44,6 +49,10 @@ type updateChannelRequest struct {
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"` RestrictModels *bool `json:"restrict_models"`
Features *string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
} }
type channelModelPricingRequest struct { type channelModelPricingRequest struct {
...@@ -71,6 +80,13 @@ type pricingIntervalRequest struct { ...@@ -71,6 +80,13 @@ type pricingIntervalRequest struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
type accountStatsPricingRuleRequest struct {
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingRequest `json:"pricing"`
}
type channelResponse struct { type channelResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
...@@ -78,9 +94,13 @@ type channelResponse struct { ...@@ -78,9 +94,13 @@ type channelResponse struct {
Status string `json:"status"` Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"` BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"` ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"` UpdatedAt string `json:"updated_at"`
} }
...@@ -112,6 +132,14 @@ type pricingIntervalResponse struct { ...@@ -112,6 +132,14 @@ type pricingIntervalResponse struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
type accountStatsPricingRuleResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingResponse `json:"pricing"`
}
func channelToResponse(ch *service.Channel) *channelResponse { func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil { if ch == nil {
return nil return nil
...@@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Description: ch.Description, Description: ch.Description,
Status: ch.Status, Status: ch.Status,
RestrictModels: ch.RestrictModels, RestrictModels: ch.RestrictModels,
Features: ch.Features,
FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs, GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping, ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
...@@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
for _, p := range ch.ModelPricing { for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p)) resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
} }
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
for _, rule := range ch.AccountStatsPricingRules {
ruleResp := accountStatsPricingRuleResponse{
ID: rule.ID,
Name: rule.Name,
GroupIDs: rule.GroupIDs,
AccountIDs: rule.AccountIDs,
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
}
if ruleResp.GroupIDs == nil {
ruleResp.GroupIDs = []int64{}
}
if ruleResp.AccountIDs == nil {
ruleResp.AccountIDs = []int64{}
}
for i := range rule.Pricing {
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
}
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
}
return resp return resp
} }
...@@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
billingMode = service.BillingModeToken billingMode = service.BillingModeToken
} }
platform := r.Platform platform := r.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals)) intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals { for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{ intervals = append(intervals, service.PricingInterval{
...@@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result return result
} }
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
return service.AccountStatsPricingRule{
Name: r.Name,
GroupIDs: r.GroupIDs,
AccountIDs: r.AccountIDs,
Pricing: pricingRequestToService(r.Pricing),
}
}
// --- Handlers --- // --- Handlers ---
// List handles listing channels with pagination // List handles listing channels with pagination
...@@ -291,6 +350,29 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -291,6 +350,29 @@ func (h *ChannelHandler) Create(c *gin.Context) {
} }
pricing := pricingRequestToService(req.ModelPricing) pricing := pricingRequestToService(req.ModelPricing)
// Main model_pricing requires a platform; default to anthropic for backward compatibility.
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
var statsRules []service.AccountStatsPricingRule
for i, r := range req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name, Name: req.Name,
...@@ -300,6 +382,10 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -300,6 +382,10 @@ func (h *ChannelHandler) Create(c *gin.Context) {
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource, BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels, RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
AccountStatsPricingRules: statsRules,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -332,11 +418,38 @@ func (h *ChannelHandler) Update(c *gin.Context) { ...@@ -332,11 +418,38 @@ func (h *ChannelHandler) Update(c *gin.Context) {
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource, BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels, RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
} }
if req.ModelPricing != nil { if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing) pricing := pricingRequestToService(*req.ModelPricing)
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
input.ModelPricing = &pricing input.ModelPricing = &pricing
} }
if req.AccountStatsPricingRules != nil {
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
for i, r := range *req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
input.AccountStatsPricingRules = &statsRules
}
channel, err := h.channelService.Update(c.Request.Context(), id, input) channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil { if err != nil {
......
...@@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) { ...@@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) {
wantValue: string(service.BillingModeToken), wantValue: string(service.BillingModeToken),
}, },
{ {
name: "empty platform defaults to anthropic", name: "empty platform stays empty",
req: channelModelPricingRequest{ req: channelModelPricingRequest{
Models: []string{"m1"}, Models: []string{"m1"},
Platform: "", Platform: "",
}, },
wantField: "Platform", wantField: "Platform",
wantValue: "anthropic", wantValue: "",
}, },
} }
......
...@@ -5,11 +5,10 @@ import ( ...@@ -5,11 +5,10 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log/slog"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
...@@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning, EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled, PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount, PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount, PaymentMaxAmount: paymentCfg.MaxAmount,
...@@ -304,6 +309,13 @@ type UpdateSettingsRequest struct { ...@@ -304,6 +309,13 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"` EnableCCHSigning *bool `json:"enable_cch_signing"`
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
// Payment configuration (integrated into settings, full replace) // Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"` PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"` PaymentMinAmount *float64 `json:"payment_min_amount"`
...@@ -881,6 +893,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -881,6 +893,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
return previousSettings.EnableCCHSigning return previousSettings.EnableCCHSigning
}(), }(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
}
return previousSettings.BalanceLowNotifyEnabled
}(),
BalanceLowNotifyThreshold: func() float64 {
if req.BalanceLowNotifyThreshold != nil {
return *req.BalanceLowNotifyThreshold
}
return previousSettings.BalanceLowNotifyThreshold
}(),
BalanceLowNotifyRechargeURL: func() string {
if req.BalanceLowNotifyRechargeURL != nil {
return *req.BalanceLowNotifyRechargeURL
}
return previousSettings.BalanceLowNotifyRechargeURL
}(),
AccountQuotaNotifyEnabled: func() bool {
if req.AccountQuotaNotifyEnabled != nil {
return *req.AccountQuotaNotifyEnabled
}
return previousSettings.AccountQuotaNotifyEnabled
}(),
AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
if req.AccountQuotaNotifyEmails != nil {
return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
}
return previousSettings.AccountQuotaNotifyEmails
}(),
} }
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
...@@ -1027,6 +1069,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1027,6 +1069,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning, EnableCCHSigning: updatedSettings.EnableCCHSigning,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled, PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount, PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount, PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
...@@ -1073,11 +1120,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys ...@@ -1073,11 +1120,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
subject, _ := middleware.GetAuthSubjectFromContext(c) subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c) role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v", slog.Info("settings updated",
time.Now().UTC().Format(time.RFC3339), "audit", true,
subject.UserID, "user_id", subject.UserID,
role, "role", role,
changed, "changed", changed,
) )
} }
...@@ -1092,6 +1139,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1092,6 +1139,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist") changed = append(changed, "registration_email_suffix_whitelist")
} }
if before.PromoCodeEnabled != after.PromoCodeEnabled {
changed = append(changed, "promo_code_enabled")
}
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
changed = append(changed, "invitation_code_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled { if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled") changed = append(changed, "password_reset_enabled")
} }
...@@ -1302,6 +1355,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1302,6 +1355,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.CustomMenuItems != after.CustomMenuItems { if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items") changed = append(changed, "custom_menu_items")
} }
if before.CustomEndpoints != after.CustomEndpoints {
changed = append(changed, "custom_endpoints")
}
if before.EnableFingerprintUnification != after.EnableFingerprintUnification { if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification") changed = append(changed, "enable_fingerprint_unification")
} }
...@@ -1311,6 +1367,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1311,6 +1367,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning { if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing") changed = append(changed, "enable_cch_signing")
} }
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
}
if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold {
changed = append(changed, "balance_low_notify_threshold")
}
if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
changed = append(changed, "balance_low_notify_recharge_url")
}
if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
changed = append(changed, "account_quota_notify_enabled")
}
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
return changed return changed
} }
...@@ -1367,6 +1439,18 @@ func equalIntSlice(a, b []int) bool { ...@@ -1367,6 +1439,18 @@ func equalIntSlice(a, b []int) bool {
return true return true
} }
func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
...@@ -1847,3 +1931,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) { ...@@ -1847,3 +1931,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes, ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
}) })
} }
// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
// GET /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg))
}
// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
// PUT /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
var cfg service.WebSearchEmulationConfig
if err := c.ShouldBindJSON(&cfg); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
response.ErrorFrom(c, err)
return
}
// Re-read (with sanitized api keys) to return current state
updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated))
}
// ResetWebSearchUsage 重置指定 provider 的配额用量
// POST /api/v1/admin/settings/web-search-emulation/reset-usage
func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) {
var req struct {
ProviderType string `json:"provider_type"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.ProviderType == "" {
response.BadRequest(c, "provider_type is required")
return
}
if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
// TestWebSearchEmulation 测试 Web Search 搜索
// POST /api/v1/admin/settings/web-search-emulation/test
func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
var req struct {
Query string `json:"query"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.Query) == "" {
req.Query = "搜索今年世界大事件"
}
result, err := service.TestWebSearch(c.Request.Context(), req.Query)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
...@@ -23,6 +23,11 @@ func UserFromServiceShallow(u *service.User) *User { ...@@ -23,6 +23,11 @@ func UserFromServiceShallow(u *service.User) *User {
AllowedGroups: u.AllowedGroups, AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt, UpdatedAt: u.UpdatedAt,
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged,
} }
} }
...@@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
out.QuotaWeeklyResetAt = &v out.QuotaWeeklyResetAt = &v
} }
} }
// 配额通知配置
if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
out.QuotaNotifyDailyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
out.QuotaNotifyDailyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
out.QuotaNotifyWeeklyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
out.QuotaNotifyWeeklyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
out.QuotaNotifyTotalEnabled = &enabled
}
if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
out.QuotaNotifyTotalThreshold = &threshold
}
} }
return out return out
...@@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { ...@@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
ModelMappingChain: l.ModelMappingChain, ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier, BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier, AccountRateMultiplier: l.AccountRateMultiplier,
AccountStatsCost: l.AccountStatsCost,
IPAddress: l.IPAddress, IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account), Account: AccountSummaryFromService(l.Account),
} }
......
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
Email string `json:"email"`
Disabled bool `json:"disabled"`
Verified bool `json:"verified"`
}
// NotifyEmailEntriesFromService converts service entries to DTO entries.
func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}
// NotifyEmailEntriesToService converts DTO entries to service entries.
func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]service.NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = service.NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}
...@@ -124,6 +124,9 @@ type SystemSettings struct { ...@@ -124,6 +124,9 @@ type SystemSettings struct {
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"` EnableCCHSigning bool `json:"enable_cch_signing"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
// Payment configuration // Payment configuration
PaymentEnabled bool `json:"payment_enabled"` PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"` PaymentMinAmount float64 `json:"payment_min_amount"`
...@@ -145,6 +148,13 @@ type SystemSettings struct { ...@@ -145,6 +148,13 @@ type SystemSettings struct {
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"` PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"` PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
...@@ -183,6 +193,10 @@ type PublicSettings struct { ...@@ -183,6 +193,10 @@ type PublicSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"` PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"` Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
} }
// OverloadCooldownSettings 529过载冷却配置 DTO // OverloadCooldownSettings 529过载冷却配置 DTO
......
...@@ -18,6 +18,13 @@ type User struct { ...@@ -18,6 +18,13 @@ type User struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"`
APIKeys []APIKey `json:"api_keys,omitempty"` APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
} }
...@@ -218,6 +225,14 @@ type Account struct { ...@@ -218,6 +225,14 @@ type Account struct {
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"` QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"` QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
// 配额通知配置
QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"` Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
...@@ -412,6 +427,8 @@ type AdminUsageLog struct { ...@@ -412,6 +427,8 @@ type AdminUsageLog struct {
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"` AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// AccountStatsCost 自定义定价规则计算的账号统计费用(nil 表示使用默认公式)
AccountStatsCost *float64 `json:"account_stats_cost,omitempty"`
// IPAddress 用户请求 IP(仅管理员可见) // IPAddress 用户请求 IP(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"` IPAddress *string `json:"ip_address,omitempty"`
......
...@@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟)
parsedReq.GroupID = apiKey.GroupID
// 计算粘性会话hash // 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{ parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c), ClientIP: ip.GetClientIP(c),
...@@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ParsedRequest: parsedReq,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
...@@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
if fs.SwitchCount > 0 { if fs.SwitchCount > 0 {
...@@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ParsedRequest: parsedReq,
APIKey: currentAPIKey, APIKey: currentAPIKey,
User: currentAPIKey.User, User: currentAPIKey.User,
Account: account, Account: account,
......
...@@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi ...@@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService nil, // channelService
nil, // resolver nil, // resolver
nil, // balanceNotifyService
) )
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
......
...@@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) { ...@@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) {
response.Success(c, gin.H{"message": "refund requested"}) response.Success(c, gin.H{"message": "refund requested"})
} }
// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"provider_instance_ids": ids})
}
// VerifyOrderRequest is the request body for verifying a payment order. // VerifyOrderRequest is the request body for verifying a payment order.
type VerifyOrderRequest struct { type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"` OutTradeNo string `json:"out_trade_no" binding:"required"`
......
...@@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled, PaymentEnabled: settings.PaymentEnabled,
Version: h.version, Version: h.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
}) })
} }
...@@ -12,12 +12,16 @@ import ( ...@@ -12,12 +12,16 @@ import (
// UserHandler handles user-related requests // UserHandler handles user-related requests
type UserHandler struct { type UserHandler struct {
userService *service.UserService userService *service.UserService
emailService *service.EmailService
emailCache service.EmailCache
} }
// NewUserHandler creates a new UserHandler // NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler { func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
return &UserHandler{ return &UserHandler{
userService: userService, userService: userService,
emailService: emailService,
emailCache: emailCache,
} }
} }
...@@ -30,6 +34,8 @@ type ChangePasswordRequest struct { ...@@ -30,6 +34,8 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload // UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Username *string `json:"username"` Username *string `json:"username"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
} }
// GetProfile handles getting user profile // GetProfile handles getting user profile
...@@ -95,6 +101,8 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -95,6 +101,8 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{ svcReq := service.UpdateProfileRequest{
Username: req.Username, Username: req.Username,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
...@@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, dto.UserFromService(updatedUser))
} }
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
type SendNotifyEmailCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
// SendNotifyEmailCode sends verification code to extra notification email
// POST /api/v1/user/notify-email/send-code
func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req SendNotifyEmailCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// VerifyNotifyEmailRequest represents the request to verify and add notify email
type VerifyNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Code string `json:"code" binding:"required,len=6"`
}
// VerifyNotifyEmail verifies code and adds email to notification list
// POST /api/v1/user/notify-email/verify
func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req VerifyNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
type RemoveNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
}
// RemoveNotifyEmail removes email from notification list
// DELETE /api/v1/user/notify-email
func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req RemoveNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
type ToggleNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Disabled bool `json:"disabled"`
}
// ToggleNotifyEmail toggles the disabled state of a notification email
// PUT /api/v1/user/notify-email/toggle
func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req ToggleNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
if err != nil {
response.ErrorFrom(c, err)
return
}
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
...@@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( ...@@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
var matched []*dbent.PaymentProviderInstance var matched []*dbent.PaymentProviderInstance
for _, inst := range instances { for _, inst := range instances {
if InstanceSupportsType(inst.SupportedTypes, paymentType) { // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
if paymentType == TypeStripe {
if inst.ProviderKey == TypeStripe {
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
matched = append(matched, inst) matched = append(matched, inst)
} }
} }
......
...@@ -18,6 +18,9 @@ const ( ...@@ -18,6 +18,9 @@ const (
BlockTypeFunction BlockTypeFunction
) )
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
type UsageMapHook func(usageMap map[string]any)
// StreamingProcessor 流式响应处理器 // StreamingProcessor 流式响应处理器
type StreamingProcessor struct { type StreamingProcessor struct {
blockType BlockType blockType BlockType
...@@ -30,6 +33,7 @@ type StreamingProcessor struct { ...@@ -30,6 +33,7 @@ type StreamingProcessor struct {
originalModel string originalModel string
webSearchQueries []string webSearchQueries []string
groundingChunks []GeminiGroundingChunk groundingChunks []GeminiGroundingChunk
usageMapHook UsageMapHook
// 累计 usage // 累计 usage
inputTokens int inputTokens int
...@@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor { ...@@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
} }
} }
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
if u.ImageOutputTokens > 0 {
m["image_output_tokens"] = u.ImageOutputTokens
}
return m
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件 // ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte { func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
...@@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID() responseID = "msg_" + generateRandomID()
} }
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
message := map[string]any{ message := map[string]any{
"id": responseID, "id": responseID,
"type": "message", "type": "message",
...@@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel, "model": p.originalModel,
"stop_reason": nil, "stop_reason": nil,
"stop_sequence": nil, "stop_sequence": nil,
"usage": usage, "usage": usageValue,
} }
event := map[string]any{ event := map[string]any{
...@@ -492,13 +525,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { ...@@ -492,13 +525,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
ImageOutputTokens: p.imageOutputTokens, ImageOutputTokens: p.imageOutputTokens,
} }
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
deltaEvent := map[string]any{ deltaEvent := map[string]any{
"type": "message_delta", "type": "message_delta",
"delta": map[string]any{ "delta": map[string]any{
"stop_reason": stopReason, "stop_reason": stopReason,
"stop_sequence": nil, "stop_sequence": nil,
}, },
"usage": usage, "usage": usageValue,
} }
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent)) _, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
......
...@@ -28,6 +28,7 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, ...@@ -28,6 +28,7 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
out := &ResponsesRequest{ out := &ResponsesRequest{
Model: req.Model, Model: req.Model,
Instructions: req.Instructions,
Input: inputJSON, Input: inputJSON,
Temperature: req.Temperature, Temperature: req.Temperature,
TopP: req.TopP, TopP: req.TopP,
......
...@@ -152,6 +152,7 @@ type AnthropicDelta struct { ...@@ -152,6 +152,7 @@ type AnthropicDelta struct {
// ResponsesRequest is the request body for POST /v1/responses. // ResponsesRequest is the request body for POST /v1/responses.
type ResponsesRequest struct { type ResponsesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Instructions string `json:"instructions,omitempty"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
MaxOutputTokens *int `json:"max_output_tokens,omitempty"` MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
...@@ -337,6 +338,7 @@ type ResponsesStreamEvent struct { ...@@ -337,6 +338,7 @@ type ResponsesStreamEvent struct {
type ChatCompletionsRequest struct { type ChatCompletionsRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []ChatMessage `json:"messages"` Messages []ChatMessage `json:"messages"`
Instructions string `json:"instructions,omitempty"` // OpenAI Responses API compat
MaxTokens *int `json:"max_tokens,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment