Commit 0b746501 authored by 陈曦's avatar 陈曦
Browse files

1. merge upstream v0.1.113 2.提交migration相关文件

parents 45061102 be7551b9
...@@ -183,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= ...@@ -183,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
...@@ -218,6 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk ...@@ -218,6 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
...@@ -251,6 +255,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= ...@@ -251,6 +255,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
...@@ -280,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv ...@@ -280,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
...@@ -312,6 +320,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= ...@@ -312,6 +320,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
......
...@@ -28,7 +28,7 @@ const ( ...@@ -28,7 +28,7 @@ const (
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support // DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// UMQ(用户消息队列)模式常量 // UMQ(用户消息队列)模式常量
const ( const (
......
...@@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { ...@@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
configPath := filepath.Join(tempDir, "config.yaml") configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644)) require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644)) yamlSafePath := filepath.ToSlash(templatePath)
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir) t.Setenv("DATA_DIR", tempDir)
cfg, err := Load() cfg, err := Load()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile) require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate) require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
} }
......
...@@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { ...@@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{ c.JSON(409, gin.H{
"error": "mixed_channel_warning", "error": "mixed_channel_warning",
"message": mixedErr.Error(), "message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
}) })
return return
} }
......
package admin package admin
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
...@@ -26,24 +27,32 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s ...@@ -26,24 +27,32 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
// --- Request / Response types --- // --- Request / Response types ---
type createChannelRequest struct { type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"` Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"` Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"` ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
} }
type updateChannelRequest struct { type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"` Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"` Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"` RestrictModels *bool `json:"restrict_models"`
Features *string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
} }
type channelModelPricingRequest struct { type channelModelPricingRequest struct {
...@@ -71,18 +80,29 @@ type pricingIntervalRequest struct { ...@@ -71,18 +80,29 @@ type pricingIntervalRequest struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
type accountStatsPricingRuleRequest struct {
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingRequest `json:"pricing"`
}
type channelResponse struct { type channelResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Status string `json:"status"` Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"` BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"` Features string `json:"features"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"` FeaturesConfig map[string]any `json:"features_config"`
ModelMapping map[string]map[string]string `json:"model_mapping"` GroupIDs []int64 `json:"group_ids"`
CreatedAt string `json:"created_at"` ModelPricing []channelModelPricingResponse `json:"model_pricing"`
UpdatedAt string `json:"updated_at"` ModelMapping map[string]map[string]string `json:"model_mapping"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
} }
type channelModelPricingResponse struct { type channelModelPricingResponse struct {
...@@ -112,6 +132,14 @@ type pricingIntervalResponse struct { ...@@ -112,6 +132,14 @@ type pricingIntervalResponse struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
type accountStatsPricingRuleResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingResponse `json:"pricing"`
}
func channelToResponse(ch *service.Channel) *channelResponse { func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil { if ch == nil {
return nil return nil
...@@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Description: ch.Description, Description: ch.Description,
Status: ch.Status, Status: ch.Status,
RestrictModels: ch.RestrictModels, RestrictModels: ch.RestrictModels,
Features: ch.Features,
FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs, GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping, ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
...@@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
for _, p := range ch.ModelPricing { for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p)) resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
} }
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
for _, rule := range ch.AccountStatsPricingRules {
ruleResp := accountStatsPricingRuleResponse{
ID: rule.ID,
Name: rule.Name,
GroupIDs: rule.GroupIDs,
AccountIDs: rule.AccountIDs,
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
}
if ruleResp.GroupIDs == nil {
ruleResp.GroupIDs = []int64{}
}
if ruleResp.AccountIDs == nil {
ruleResp.AccountIDs = []int64{}
}
for i := range rule.Pricing {
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
}
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
}
return resp return resp
} }
...@@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
billingMode = service.BillingModeToken billingMode = service.BillingModeToken
} }
platform := r.Platform platform := r.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals)) intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals { for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{ intervals = append(intervals, service.PricingInterval{
...@@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result return result
} }
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
return service.AccountStatsPricingRule{
Name: r.Name,
GroupIDs: r.GroupIDs,
AccountIDs: r.AccountIDs,
Pricing: pricingRequestToService(r.Pricing),
}
}
// --- Handlers --- // --- Handlers ---
// List handles listing channels with pagination // List handles listing channels with pagination
...@@ -291,15 +350,42 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -291,15 +350,42 @@ func (h *ChannelHandler) Create(c *gin.Context) {
} }
pricing := pricingRequestToService(req.ModelPricing) pricing := pricingRequestToService(req.ModelPricing)
// Main model_pricing requires a platform; default to anthropic for backward compatibility.
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
var statsRules []service.AccountStatsPricingRule
for i, r := range req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
ModelPricing: pricing, ModelPricing: pricing,
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource, BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels, RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
AccountStatsPricingRules: statsRules,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -325,18 +411,45 @@ func (h *ChannelHandler) Update(c *gin.Context) { ...@@ -325,18 +411,45 @@ func (h *ChannelHandler) Update(c *gin.Context) {
} }
input := &service.UpdateChannelInput{ input := &service.UpdateChannelInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Status: req.Status, Status: req.Status,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource, BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels, RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
} }
if req.ModelPricing != nil { if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing) pricing := pricingRequestToService(*req.ModelPricing)
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
input.ModelPricing = &pricing input.ModelPricing = &pricing
} }
if req.AccountStatsPricingRules != nil {
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
for i, r := range *req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
input.AccountStatsPricingRules = &statsRules
}
channel, err := h.channelService.Update(c.Request.Context(), id, input) channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil { if err != nil {
......
...@@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) { ...@@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) {
wantValue: string(service.BillingModeToken), wantValue: string(service.BillingModeToken),
}, },
{ {
name: "empty platform defaults to anthropic", name: "empty platform stays empty",
req: channelModelPricingRequest{ req: channelModelPricingRequest{
Models: []string{"m1"}, Models: []string{"m1"},
Platform: "", Platform: "",
}, },
wantField: "Platform", wantField: "Platform",
wantValue: "anthropic", wantValue: "",
}, },
} }
......
...@@ -5,11 +5,10 @@ import ( ...@@ -5,11 +5,10 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log/slog"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
...@@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning, EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled, PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount, PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount, PaymentMaxAmount: paymentCfg.MaxAmount,
...@@ -183,6 +188,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -183,6 +188,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders, PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
PaymentEnabledTypes: paymentCfg.EnabledTypes, PaymentEnabledTypes: paymentCfg.EnabledTypes,
PaymentBalanceDisabled: paymentCfg.BalanceDisabled, PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy, PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: paymentCfg.ProductNamePrefix, PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: paymentCfg.ProductNameSuffix, PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
...@@ -304,20 +311,29 @@ type UpdateSettingsRequest struct { ...@@ -304,20 +311,29 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"` EnableCCHSigning *bool `json:"enable_cch_signing"`
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
// Payment configuration (integrated into settings, full replace) // Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"` PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"` PaymentMinAmount *float64 `json:"payment_min_amount"`
PaymentMaxAmount *float64 `json:"payment_max_amount"` PaymentMaxAmount *float64 `json:"payment_max_amount"`
PaymentDailyLimit *float64 `json:"payment_daily_limit"` PaymentDailyLimit *float64 `json:"payment_daily_limit"`
PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"` PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"`
PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"` PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"`
PaymentEnabledTypes []string `json:"payment_enabled_types"` PaymentEnabledTypes []string `json:"payment_enabled_types"`
PaymentBalanceDisabled *bool `json:"payment_balance_disabled"` PaymentBalanceDisabled *bool `json:"payment_balance_disabled"`
PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"` PaymentBalanceRechargeMultiplier *float64 `json:"payment_balance_recharge_multiplier"`
PaymentProductNamePrefix *string `json:"payment_product_name_prefix"` PaymentRechargeFeeRate *float64 `json:"payment_recharge_fee_rate"`
PaymentProductNameSuffix *string `json:"payment_product_name_suffix"` PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"`
PaymentHelpImageURL *string `json:"payment_help_image_url"` PaymentProductNamePrefix *string `json:"payment_product_name_prefix"`
PaymentHelpText *string `json:"payment_help_text"` PaymentProductNameSuffix *string `json:"payment_product_name_suffix"`
PaymentHelpImageURL *string `json:"payment_help_image_url"`
PaymentHelpText *string `json:"payment_help_text"`
// Cancel rate limit // Cancel rate limit
PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"` PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"`
...@@ -881,6 +897,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -881,6 +897,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
return previousSettings.EnableCCHSigning return previousSettings.EnableCCHSigning
}(), }(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
}
return previousSettings.BalanceLowNotifyEnabled
}(),
BalanceLowNotifyThreshold: func() float64 {
if req.BalanceLowNotifyThreshold != nil {
return *req.BalanceLowNotifyThreshold
}
return previousSettings.BalanceLowNotifyThreshold
}(),
BalanceLowNotifyRechargeURL: func() string {
if req.BalanceLowNotifyRechargeURL != nil {
return *req.BalanceLowNotifyRechargeURL
}
return previousSettings.BalanceLowNotifyRechargeURL
}(),
AccountQuotaNotifyEnabled: func() bool {
if req.AccountQuotaNotifyEnabled != nil {
return *req.AccountQuotaNotifyEnabled
}
return previousSettings.AccountQuotaNotifyEnabled
}(),
AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
if req.AccountQuotaNotifyEmails != nil {
return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
}
return previousSettings.AccountQuotaNotifyEmails
}(),
} }
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
...@@ -892,24 +938,26 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -892,24 +938,26 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
// Skip if no payment fields were provided (prevents accidental wipe). // Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) { if h.paymentConfigService != nil && hasPaymentFields(req) {
paymentReq := service.UpdatePaymentConfigRequest{ paymentReq := service.UpdatePaymentConfigRequest{
Enabled: req.PaymentEnabled, Enabled: req.PaymentEnabled,
MinAmount: req.PaymentMinAmount, MinAmount: req.PaymentMinAmount,
MaxAmount: req.PaymentMaxAmount, MaxAmount: req.PaymentMaxAmount,
DailyLimit: req.PaymentDailyLimit, DailyLimit: req.PaymentDailyLimit,
OrderTimeoutMin: req.PaymentOrderTimeoutMin, OrderTimeoutMin: req.PaymentOrderTimeoutMin,
MaxPendingOrders: req.PaymentMaxPendingOrders, MaxPendingOrders: req.PaymentMaxPendingOrders,
EnabledTypes: req.PaymentEnabledTypes, EnabledTypes: req.PaymentEnabledTypes,
BalanceDisabled: req.PaymentBalanceDisabled, BalanceDisabled: req.PaymentBalanceDisabled,
LoadBalanceStrategy: req.PaymentLoadBalanceStrat, BalanceRechargeMultiplier: req.PaymentBalanceRechargeMultiplier,
ProductNamePrefix: req.PaymentProductNamePrefix, RechargeFeeRate: req.PaymentRechargeFeeRate,
ProductNameSuffix: req.PaymentProductNameSuffix, LoadBalanceStrategy: req.PaymentLoadBalanceStrat,
HelpImageURL: req.PaymentHelpImageURL, ProductNamePrefix: req.PaymentProductNamePrefix,
HelpText: req.PaymentHelpText, ProductNameSuffix: req.PaymentProductNameSuffix,
CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled, HelpImageURL: req.PaymentHelpImageURL,
CancelRateLimitMax: req.PaymentCancelRateLimitMax, HelpText: req.PaymentHelpText,
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow, CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled,
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit, CancelRateLimitMax: req.PaymentCancelRateLimitMax,
CancelRateLimitMode: req.PaymentCancelRateLimitMode, CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
} }
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil { if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -1027,6 +1075,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1027,6 +1075,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning, EnableCCHSigning: updatedSettings.EnableCCHSigning,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled, PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount, PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount, PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
...@@ -1035,6 +1088,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1035,6 +1088,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders, PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes, PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled, PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy, PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix, PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix, PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
...@@ -1054,6 +1109,7 @@ func hasPaymentFields(req UpdateSettingsRequest) bool { ...@@ -1054,6 +1109,7 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil || req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil || req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil ||
req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil || req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil ||
req.PaymentBalanceRechargeMultiplier != nil || req.PaymentRechargeFeeRate != nil ||
req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil || req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil ||
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil || req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil || req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
...@@ -1073,11 +1129,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys ...@@ -1073,11 +1129,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
subject, _ := middleware.GetAuthSubjectFromContext(c) subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c) role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v", slog.Info("settings updated",
time.Now().UTC().Format(time.RFC3339), "audit", true,
subject.UserID, "user_id", subject.UserID,
role, "role", role,
changed, "changed", changed,
) )
} }
...@@ -1092,6 +1148,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1092,6 +1148,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist") changed = append(changed, "registration_email_suffix_whitelist")
} }
if before.PromoCodeEnabled != after.PromoCodeEnabled {
changed = append(changed, "promo_code_enabled")
}
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
changed = append(changed, "invitation_code_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled { if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled") changed = append(changed, "password_reset_enabled")
} }
...@@ -1302,6 +1364,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1302,6 +1364,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.CustomMenuItems != after.CustomMenuItems { if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items") changed = append(changed, "custom_menu_items")
} }
if before.CustomEndpoints != after.CustomEndpoints {
changed = append(changed, "custom_endpoints")
}
if before.EnableFingerprintUnification != after.EnableFingerprintUnification { if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification") changed = append(changed, "enable_fingerprint_unification")
} }
...@@ -1311,6 +1376,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1311,6 +1376,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning { if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing") changed = append(changed, "enable_cch_signing")
} }
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
}
if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold {
changed = append(changed, "balance_low_notify_threshold")
}
if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
changed = append(changed, "balance_low_notify_recharge_url")
}
if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
changed = append(changed, "account_quota_notify_enabled")
}
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
return changed return changed
} }
...@@ -1367,6 +1448,18 @@ func equalIntSlice(a, b []int) bool { ...@@ -1367,6 +1448,18 @@ func equalIntSlice(a, b []int) bool {
return true return true
} }
func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
...@@ -1847,3 +1940,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) { ...@@ -1847,3 +1940,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes, ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
}) })
} }
// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
// GET /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg))
}
// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
// PUT /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
var cfg service.WebSearchEmulationConfig
if err := c.ShouldBindJSON(&cfg); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
response.ErrorFrom(c, err)
return
}
// Re-read (with sanitized api keys) to return current state
updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated))
}
// ResetWebSearchUsage 重置指定 provider 的配额用量
// POST /api/v1/admin/settings/web-search-emulation/reset-usage
func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) {
var req struct {
ProviderType string `json:"provider_type"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.ProviderType == "" {
response.BadRequest(c, "provider_type is required")
return
}
if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
// TestWebSearchEmulation 测试 Web Search 搜索
// POST /api/v1/admin/settings/web-search-emulation/test
func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
var req struct {
Query string `json:"query"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.Query) == "" {
req.Query = "搜索今年世界大事件"
}
result, err := service.TestWebSearch(c.Request.Context(), req.Query)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
...@@ -13,16 +13,21 @@ func UserFromServiceShallow(u *service.User) *User { ...@@ -13,16 +13,21 @@ func UserFromServiceShallow(u *service.User) *User {
return nil return nil
} }
return &User{ return &User{
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Username: u.Username, Username: u.Username,
Role: u.Role, Role: u.Role,
Balance: u.Balance, Balance: u.Balance,
Concurrency: u.Concurrency, Concurrency: u.Concurrency,
Status: u.Status, Status: u.Status,
AllowedGroups: u.AllowedGroups, AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt, UpdatedAt: u.UpdatedAt,
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged,
} }
} }
...@@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
out.QuotaWeeklyResetAt = &v out.QuotaWeeklyResetAt = &v
} }
} }
// 配额通知配置
if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
out.QuotaNotifyDailyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
out.QuotaNotifyDailyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
out.QuotaNotifyWeeklyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
out.QuotaNotifyWeeklyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
out.QuotaNotifyTotalEnabled = &enabled
}
if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
out.QuotaNotifyTotalThreshold = &threshold
}
} }
return out return out
...@@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { ...@@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
ModelMappingChain: l.ModelMappingChain, ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier, BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier, AccountRateMultiplier: l.AccountRateMultiplier,
AccountStatsCost: l.AccountStatsCost,
IPAddress: l.IPAddress, IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account), Account: AccountSummaryFromService(l.Account),
} }
......
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
Email string `json:"email"`
Disabled bool `json:"disabled"`
Verified bool `json:"verified"`
}
// NotifyEmailEntriesFromService converts service entries to DTO entries.
func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}
// NotifyEmailEntriesToService converts DTO entries to service entries.
func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]service.NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = service.NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}
...@@ -124,20 +124,25 @@ type SystemSettings struct { ...@@ -124,20 +124,25 @@ type SystemSettings struct {
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"` EnableCCHSigning bool `json:"enable_cch_signing"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
// Payment configuration // Payment configuration
PaymentEnabled bool `json:"payment_enabled"` PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"` PaymentMinAmount float64 `json:"payment_min_amount"`
PaymentMaxAmount float64 `json:"payment_max_amount"` PaymentMaxAmount float64 `json:"payment_max_amount"`
PaymentDailyLimit float64 `json:"payment_daily_limit"` PaymentDailyLimit float64 `json:"payment_daily_limit"`
PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"` PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"`
PaymentMaxPendingOrders int `json:"payment_max_pending_orders"` PaymentMaxPendingOrders int `json:"payment_max_pending_orders"`
PaymentEnabledTypes []string `json:"payment_enabled_types"` PaymentEnabledTypes []string `json:"payment_enabled_types"`
PaymentBalanceDisabled bool `json:"payment_balance_disabled"` PaymentBalanceDisabled bool `json:"payment_balance_disabled"`
PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"` PaymentBalanceRechargeMultiplier float64 `json:"payment_balance_recharge_multiplier"`
PaymentProductNamePrefix string `json:"payment_product_name_prefix"` PaymentRechargeFeeRate float64 `json:"payment_recharge_fee_rate"`
PaymentProductNameSuffix string `json:"payment_product_name_suffix"` PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"`
PaymentHelpImageURL string `json:"payment_help_image_url"` PaymentProductNamePrefix string `json:"payment_product_name_prefix"`
PaymentHelpText string `json:"payment_help_text"` PaymentProductNameSuffix string `json:"payment_product_name_suffix"`
PaymentHelpImageURL string `json:"payment_help_image_url"`
PaymentHelpText string `json:"payment_help_text"`
// Cancel rate limit // Cancel rate limit
PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"` PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"`
...@@ -145,6 +150,13 @@ type SystemSettings struct { ...@@ -145,6 +150,13 @@ type SystemSettings struct {
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"` PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"` PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
...@@ -183,6 +195,10 @@ type PublicSettings struct { ...@@ -183,6 +195,10 @@ type PublicSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"` PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"` Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
} }
// OverloadCooldownSettings 529过载冷却配置 DTO // OverloadCooldownSettings 529过载冷却配置 DTO
......
...@@ -18,6 +18,13 @@ type User struct { ...@@ -18,6 +18,13 @@ type User struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"`
APIKeys []APIKey `json:"api_keys,omitempty"` APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
} }
...@@ -218,6 +225,14 @@ type Account struct { ...@@ -218,6 +225,14 @@ type Account struct {
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"` QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"` QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
// 配额通知配置
QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"` Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
...@@ -412,6 +427,8 @@ type AdminUsageLog struct { ...@@ -412,6 +427,8 @@ type AdminUsageLog struct {
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"` AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// AccountStatsCost 自定义定价规则计算的账号统计费用(nil 表示使用默认公式)
AccountStatsCost *float64 `json:"account_stats_cost,omitempty"`
// IPAddress 用户请求 IP(仅管理员可见) // IPAddress 用户请求 IP(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"` IPAddress *string `json:"ip_address,omitempty"`
......
...@@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟)
parsedReq.GroupID = apiKey.GroupID
// 计算粘性会话hash // 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{ parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c), ClientIP: ip.GetClientIP(c),
...@@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ParsedRequest: parsedReq,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
...@@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
if fs.SwitchCount > 0 { if fs.SwitchCount > 0 {
...@@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ParsedRequest: parsedReq,
APIKey: currentAPIKey, APIKey: currentAPIKey,
User: currentAPIKey.User, User: currentAPIKey.User,
Account: account, Account: account,
......
...@@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi ...@@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService nil, // channelService
nil, // resolver nil, // resolver
nil, // balanceNotifyService
) )
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
......
...@@ -126,26 +126,30 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) { ...@@ -126,26 +126,30 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
} }
response.Success(c, checkoutInfoResponse{ response.Success(c, checkoutInfoResponse{
Methods: limitsResp.Methods, Methods: limitsResp.Methods,
GlobalMin: limitsResp.GlobalMin, GlobalMin: limitsResp.GlobalMin,
GlobalMax: limitsResp.GlobalMax, GlobalMax: limitsResp.GlobalMax,
Plans: planList, Plans: planList,
BalanceDisabled: cfg.BalanceDisabled, BalanceDisabled: cfg.BalanceDisabled,
HelpText: cfg.HelpText, BalanceRechargeMultiplier: cfg.BalanceRechargeMultiplier,
HelpImageURL: cfg.HelpImageURL, RechargeFeeRate: cfg.RechargeFeeRate,
StripePublishableKey: cfg.StripePublishableKey, HelpText: cfg.HelpText,
HelpImageURL: cfg.HelpImageURL,
StripePublishableKey: cfg.StripePublishableKey,
}) })
} }
type checkoutInfoResponse struct { type checkoutInfoResponse struct {
Methods map[string]service.MethodLimits `json:"methods"` Methods map[string]service.MethodLimits `json:"methods"`
GlobalMin float64 `json:"global_min"` GlobalMin float64 `json:"global_min"`
GlobalMax float64 `json:"global_max"` GlobalMax float64 `json:"global_max"`
Plans []checkoutPlan `json:"plans"` Plans []checkoutPlan `json:"plans"`
BalanceDisabled bool `json:"balance_disabled"` BalanceDisabled bool `json:"balance_disabled"`
HelpText string `json:"help_text"` BalanceRechargeMultiplier float64 `json:"balance_recharge_multiplier"`
HelpImageURL string `json:"help_image_url"` RechargeFeeRate float64 `json:"recharge_fee_rate"`
StripePublishableKey string `json:"stripe_publishable_key"` HelpText string `json:"help_text"`
HelpImageURL string `json:"help_image_url"`
StripePublishableKey string `json:"stripe_publishable_key"`
} }
type checkoutPlan struct { type checkoutPlan struct {
...@@ -335,6 +339,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) { ...@@ -335,6 +339,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) {
response.Success(c, gin.H{"message": "refund requested"}) response.Success(c, gin.H{"message": "refund requested"})
} }
// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"provider_instance_ids": ids})
}
// VerifyOrderRequest is the request body for verifying a payment order. // VerifyOrderRequest is the request body for verifying a payment order.
type VerifyOrderRequest struct { type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"` OutTradeNo string `json:"out_trade_no" binding:"required"`
...@@ -371,6 +385,7 @@ type PublicOrderResult struct { ...@@ -371,6 +385,7 @@ type PublicOrderResult struct {
Amount float64 `json:"amount"` Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"` PayAmount float64 `json:"pay_amount"`
PaymentType string `json:"payment_type"` PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"` Status string `json:"status"`
} }
...@@ -394,6 +409,7 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { ...@@ -394,6 +409,7 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
Amount: order.Amount, Amount: order.Amount,
PayAmount: order.PayAmount, PayAmount: order.PayAmount,
PaymentType: order.PaymentType, PaymentType: order.PaymentType,
OrderType: order.OrderType,
Status: order.Status, Status: order.Status,
}) })
} }
......
...@@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled, PaymentEnabled: settings.PaymentEnabled,
Version: h.version, Version: h.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
}) })
} }
...@@ -11,13 +11,17 @@ import ( ...@@ -11,13 +11,17 @@ import (
// UserHandler handles user-related requests // UserHandler handles user-related requests
type UserHandler struct { type UserHandler struct {
userService *service.UserService userService *service.UserService
emailService *service.EmailService
emailCache service.EmailCache
} }
// NewUserHandler creates a new UserHandler // NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler { func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
return &UserHandler{ return &UserHandler{
userService: userService, userService: userService,
emailService: emailService,
emailCache: emailCache,
} }
} }
...@@ -29,7 +33,9 @@ type ChangePasswordRequest struct { ...@@ -29,7 +33,9 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload // UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Username *string `json:"username"` Username *string `json:"username"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
} }
// GetProfile handles getting user profile // GetProfile handles getting user profile
...@@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
} }
svcReq := service.UpdateProfileRequest{ svcReq := service.UpdateProfileRequest{
Username: req.Username, Username: req.Username,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
...@@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, dto.UserFromService(updatedUser))
} }
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
type SendNotifyEmailCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
// SendNotifyEmailCode sends verification code to extra notification email
// POST /api/v1/user/notify-email/send-code
func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req SendNotifyEmailCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// VerifyNotifyEmailRequest represents the request to verify and add notify email
type VerifyNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Code string `json:"code" binding:"required,len=6"`
}
// VerifyNotifyEmail verifies code and adds email to notification list
// POST /api/v1/user/notify-email/verify
func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req VerifyNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
type RemoveNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
}
// RemoveNotifyEmail removes email from notification list
// DELETE /api/v1/user/notify-email
func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req RemoveNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
type ToggleNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Disabled bool `json:"disabled"`
}
// ToggleNotifyEmail toggles the disabled state of a notification email
// PUT /api/v1/user/notify-email/toggle
func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req ToggleNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
if err != nil {
response.ErrorFrom(c, err)
return
}
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
...@@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( ...@@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
var matched []*dbent.PaymentProviderInstance var matched []*dbent.PaymentProviderInstance
for _, inst := range instances { for _, inst := range instances {
if InstanceSupportsType(inst.SupportedTypes, paymentType) { // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
if paymentType == TypeStripe {
if inst.ProviderKey == TypeStripe {
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
matched = append(matched, inst) matched = append(matched, inst)
} }
} }
......
...@@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) { ...@@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) {
wantIDs: nil, wantIDs: nil,
}, },
{ {
name: "empty candidates returns empty", name: "empty candidates returns empty",
candidates: nil, candidates: nil,
paymentType: "alipay", paymentType: "alipay",
orderAmount: 10, orderAmount: 10,
......
...@@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) { ...@@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) {
errSubstr: "privateKey", errSubstr: "privateKey",
}, },
{ {
name: "nil config map returns error for appId", name: "nil config map returns error for appId",
config: map[string]string{}, config: map[string]string{},
wantErr: true, wantErr: true,
errSubstr: "appId", errSubstr: "appId",
}, },
} }
......
...@@ -18,6 +18,9 @@ const ( ...@@ -18,6 +18,9 @@ const (
BlockTypeFunction BlockTypeFunction
) )
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
type UsageMapHook func(usageMap map[string]any)
// StreamingProcessor 流式响应处理器 // StreamingProcessor 流式响应处理器
type StreamingProcessor struct { type StreamingProcessor struct {
blockType BlockType blockType BlockType
...@@ -30,6 +33,7 @@ type StreamingProcessor struct { ...@@ -30,6 +33,7 @@ type StreamingProcessor struct {
originalModel string originalModel string
webSearchQueries []string webSearchQueries []string
groundingChunks []GeminiGroundingChunk groundingChunks []GeminiGroundingChunk
usageMapHook UsageMapHook
// 累计 usage // 累计 usage
inputTokens int inputTokens int
...@@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor { ...@@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
} }
} }
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
if u.ImageOutputTokens > 0 {
m["image_output_tokens"] = u.ImageOutputTokens
}
return m
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件 // ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte { func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
...@@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID() responseID = "msg_" + generateRandomID()
} }
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
message := map[string]any{ message := map[string]any{
"id": responseID, "id": responseID,
"type": "message", "type": "message",
...@@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel, "model": p.originalModel,
"stop_reason": nil, "stop_reason": nil,
"stop_sequence": nil, "stop_sequence": nil,
"usage": usage, "usage": usageValue,
} }
event := map[string]any{ event := map[string]any{
...@@ -492,13 +525,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { ...@@ -492,13 +525,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
ImageOutputTokens: p.imageOutputTokens, ImageOutputTokens: p.imageOutputTokens,
} }
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
deltaEvent := map[string]any{ deltaEvent := map[string]any{
"type": "message_delta", "type": "message_delta",
"delta": map[string]any{ "delta": map[string]any{
"stop_reason": stopReason, "stop_reason": stopReason,
"stop_sequence": nil, "stop_sequence": nil,
}, },
"usage": usage, "usage": usageValue,
} }
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent)) _, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment