Commit b9b4db3d authored by song's avatar song
Browse files

Merge upstream/main

parents 5a6f60a9 dae0d532
......@@ -22,6 +22,10 @@ type Tx struct {
AccountGroup *AccountGroupClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
PromoCodeUsage *PromoCodeUsageClient
// Proxy is the client for interacting with the Proxy builders.
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
......@@ -175,6 +179,8 @@ func (tx *Tx) init() {
tx.Account = NewAccountClient(tx.config)
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
......
......@@ -62,6 +62,8 @@ type UsageLog struct {
ActualCost float64 `json:"actual_cost,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
// BillingType holds the value of the "billing_type" field.
BillingType int8 `json:"billing_type,omitempty"`
// Stream holds the value of the "stream" field.
......@@ -72,6 +74,8 @@ type UsageLog struct {
FirstTokenMs *int `json:"first_token_ms,omitempty"`
// UserAgent holds the value of the "user_agent" field.
UserAgent *string `json:"user_agent,omitempty"`
// IPAddress holds the value of the "ip_address" field.
IPAddress *string `json:"ip_address,omitempty"`
// ImageCount holds the value of the "image_count" field.
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
......@@ -163,11 +167,11 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case usagelog.FieldStream:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
......@@ -314,6 +318,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case usagelog.FieldAccountRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
} else if value.Valid {
_m.AccountRateMultiplier = new(float64)
*_m.AccountRateMultiplier = value.Float64
}
case usagelog.FieldBillingType:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
......@@ -347,6 +358,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.UserAgent = new(string)
*_m.UserAgent = value.String
}
case usagelog.FieldIPAddress:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field ip_address", values[i])
} else if value.Valid {
_m.IPAddress = new(string)
*_m.IPAddress = value.String
}
case usagelog.FieldImageCount:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field image_count", values[i])
......@@ -491,6 +509,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
if v := _m.AccountRateMultiplier; v != nil {
builder.WriteString("account_rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("billing_type=")
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
builder.WriteString(", ")
......@@ -512,6 +535,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.IPAddress; v != nil {
builder.WriteString("ip_address=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("image_count=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
builder.WriteString(", ")
......
......@@ -54,6 +54,8 @@ const (
FieldActualCost = "actual_cost"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
FieldAccountRateMultiplier = "account_rate_multiplier"
// FieldBillingType holds the string denoting the billing_type field in the database.
FieldBillingType = "billing_type"
// FieldStream holds the string denoting the stream field in the database.
......@@ -64,6 +66,8 @@ const (
FieldFirstTokenMs = "first_token_ms"
// FieldUserAgent holds the string denoting the user_agent field in the database.
FieldUserAgent = "user_agent"
// FieldIPAddress holds the string denoting the ip_address field in the database.
FieldIPAddress = "ip_address"
// FieldImageCount holds the string denoting the image_count field in the database.
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
......@@ -142,11 +146,13 @@ var Columns = []string{
FieldTotalCost,
FieldActualCost,
FieldRateMultiplier,
FieldAccountRateMultiplier,
FieldBillingType,
FieldStream,
FieldDurationMs,
FieldFirstTokenMs,
FieldUserAgent,
FieldIPAddress,
FieldImageCount,
FieldImageSize,
FieldCreatedAt,
......@@ -199,6 +205,8 @@ var (
DefaultStream bool
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
UserAgentValidator func(string) error
// IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
IPAddressValidator func(string) error
// DefaultImageCount holds the default value on creation for the "image_count" field.
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
......@@ -315,6 +323,11 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
}
// ByBillingType orders the results by the billing_type field.
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingType, opts...).ToFunc()
......@@ -340,6 +353,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
}
// ByIPAddress orders the results by the ip_address field.
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
}
// ByImageCount orders the results by the image_count field.
func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageCount, opts...).ToFunc()
......
......@@ -155,6 +155,11 @@ func RateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
}
// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
func AccountRateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
func BillingType(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
......@@ -180,6 +185,11 @@ func UserAgent(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
}
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
func IPAddress(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
func ImageCount(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
......@@ -965,6 +975,56 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
}
// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
}
// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
}
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
func BillingTypeEQ(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
......@@ -1190,6 +1250,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
}
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
func IPAddressEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
func IPAddressNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v))
}
// IPAddressIn applies the In predicate on the "ip_address" field.
func IPAddressIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...))
}
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
func IPAddressNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...))
}
// IPAddressGT applies the GT predicate on the "ip_address" field.
func IPAddressGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v))
}
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
func IPAddressGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v))
}
// IPAddressLT applies the LT predicate on the "ip_address" field.
func IPAddressLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v))
}
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
func IPAddressLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v))
}
// IPAddressContains applies the Contains predicate on the "ip_address" field.
func IPAddressContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v))
}
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
func IPAddressHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v))
}
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
func IPAddressHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v))
}
// IPAddressIsNil applies the IsNil predicate on the "ip_address" field.
func IPAddressIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress))
}
// IPAddressNotNil applies the NotNil predicate on the "ip_address" field.
func IPAddressNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress))
}
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
func IPAddressEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v))
}
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
func IPAddressContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v))
}
// ImageCountEQ applies the EQ predicate on the "image_count" field.
func ImageCountEQ(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
......
......@@ -267,6 +267,20 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
return _c
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
_c.mutation.SetAccountRateMultiplier(v)
return _c
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
if v != nil {
_c.SetAccountRateMultiplier(*v)
}
return _c
}
// SetBillingType sets the "billing_type" field.
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
_c.mutation.SetBillingType(v)
......@@ -337,6 +351,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate {
return _c
}
// SetIPAddress sets the "ip_address" field.
func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate {
_c.mutation.SetIPAddress(v)
return _c
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate {
if v != nil {
_c.SetIPAddress(*v)
}
return _c
}
// SetImageCount sets the "image_count" field.
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
_c.mutation.SetImageCount(v)
......@@ -586,6 +614,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
}
}
if v, ok := _c.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if _, ok := _c.mutation.ImageCount(); !ok {
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
}
......@@ -693,6 +726,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
_node.AccountRateMultiplier = &value
}
if value, ok := _c.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
_node.BillingType = value
......@@ -713,6 +750,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
_node.UserAgent = &value
}
if value, ok := _c.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
_node.IPAddress = &value
}
if value, ok := _c.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
_node.ImageCount = value
......@@ -1192,6 +1233,30 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
return u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Set(usagelog.FieldAccountRateMultiplier, v)
return u
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldAccountRateMultiplier)
return u
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Add(usagelog.FieldAccountRateMultiplier, v)
return u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
u.SetNull(usagelog.FieldAccountRateMultiplier)
return u
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
u.Set(usagelog.FieldBillingType, v)
......@@ -1288,6 +1353,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert {
return u
}
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert {
u.Set(usagelog.FieldIPAddress, v)
return u
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldIPAddress)
return u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert {
u.SetNull(usagelog.FieldIPAddress)
return u
}
// SetImageCount sets the "image_count" field.
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
u.Set(usagelog.FieldImageCount, v)
......@@ -1754,6 +1837,34 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
......@@ -1866,6 +1977,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne {
})
}
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
......@@ -2504,6 +2636,34 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
......@@ -2616,6 +2776,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk {
})
}
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
......
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -32,6 +33,7 @@ type UsageLogQuery struct {
withAccount *AccountQuery
withGroup *GroupQuery
withSubscription *UserSubscriptionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -531,6 +533,9 @@ func (_q *UsageLogQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Usa
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -727,6 +732,9 @@ func (_q *UsageLogQuery) loadSubscription(ctx context.Context, query *UserSubscr
func (_q *UsageLogQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -804,6 +812,9 @@ func (_q *UsageLogQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -821,6 +832,32 @@ func (_q *UsageLogQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UsageLogQuery) ForUpdate(opts ...sql.LockOption) *UsageLogQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UsageLogQuery) ForShare(opts ...sql.LockOption) *UsageLogQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UsageLogGroupBy is the group-by builder for UsageLog entities.
type UsageLogGroupBy struct {
selector
......
......@@ -415,6 +415,33 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
_u.mutation.ResetBillingType()
......@@ -524,6 +551,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate {
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
_u.mutation.ResetImageCount()
......@@ -669,6 +716,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
......@@ -782,6 +834,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}
......@@ -815,6 +876,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
}
......@@ -1375,6 +1442,33 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
_u.mutation.ResetBillingType()
......@@ -1484,6 +1578,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne {
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
_u.mutation.ResetImageCount()
......@@ -1642,6 +1756,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
......@@ -1772,6 +1891,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}
......@@ -1805,6 +1933,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
}
......
......@@ -61,11 +61,13 @@ type UserEdges struct {
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// AttributeValues holds the value of the attribute_values edge.
AttributeValues []*UserAttributeValue `json:"attribute_values,omitempty"`
// PromoCodeUsages holds the value of the promo_code_usages edge.
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [8]bool
loadedTypes [9]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
......@@ -131,10 +133,19 @@ func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
return nil, &NotLoadedError{edge: "attribute_values"}
}
// PromoCodeUsagesOrErr returns the PromoCodeUsages value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
if e.loadedTypes[7] {
return e.PromoCodeUsages, nil
}
return nil, &NotLoadedError{edge: "promo_code_usages"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[7] {
if e.loadedTypes[8] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
......@@ -289,6 +300,11 @@ func (_m *User) QueryAttributeValues() *UserAttributeValueQuery {
return NewUserClient(_m.config).QueryAttributeValues(_m)
}
// QueryPromoCodeUsages queries the "promo_code_usages" edge of the User entity.
func (_m *User) QueryPromoCodeUsages() *PromoCodeUsageQuery {
return NewUserClient(_m.config).QueryPromoCodeUsages(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
......
......@@ -51,6 +51,8 @@ const (
EdgeUsageLogs = "usage_logs"
// EdgeAttributeValues holds the string denoting the attribute_values edge name in mutations.
EdgeAttributeValues = "attribute_values"
// EdgePromoCodeUsages holds the string denoting the promo_code_usages edge name in mutations.
EdgePromoCodeUsages = "promo_code_usages"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
......@@ -102,6 +104,13 @@ const (
AttributeValuesInverseTable = "user_attribute_values"
// AttributeValuesColumn is the table column denoting the attribute_values relation/edge.
AttributeValuesColumn = "user_id"
// PromoCodeUsagesTable is the table that holds the promo_code_usages relation/edge.
PromoCodeUsagesTable = "promo_code_usages"
// PromoCodeUsagesInverseTable is the table name for the PromoCodeUsage entity.
// It exists in this package in order to avoid circular dependency with the "promocodeusage" package.
PromoCodeUsagesInverseTable = "promo_code_usages"
// PromoCodeUsagesColumn is the table column denoting the promo_code_usages relation/edge.
PromoCodeUsagesColumn = "user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
......@@ -342,6 +351,20 @@ func ByAttributeValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByPromoCodeUsagesCount orders the results by promo_code_usages count.
func ByPromoCodeUsagesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newPromoCodeUsagesStep(), opts...)
}
}
// ByPromoCodeUsages orders the results by promo_code_usages terms.
func ByPromoCodeUsages(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newPromoCodeUsagesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
......@@ -404,6 +427,13 @@ func newAttributeValuesStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
)
}
func newPromoCodeUsagesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PromoCodeUsagesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PromoCodeUsagesTable, PromoCodeUsagesColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
......
......@@ -871,6 +871,29 @@ func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.Use
})
}
// HasPromoCodeUsages applies the HasEdge predicate on the "promo_code_usages" edge.
func HasPromoCodeUsages() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PromoCodeUsagesTable, PromoCodeUsagesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPromoCodeUsagesWith applies the HasEdge predicate on the "promo_code_usages" edge with a given conditions (other predicates).
func HasPromoCodeUsagesWith(preds ...predicate.PromoCodeUsage) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPromoCodeUsagesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
......
......@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -271,6 +272,21 @@ func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
return _c.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_c *UserCreate) AddPromoCodeUsageIDs(ids ...int64) *UserCreate {
_c.mutation.AddPromoCodeUsageIDs(ids...)
return _c
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_c *UserCreate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
......@@ -593,6 +609,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
......
......@@ -9,12 +9,14 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -37,7 +39,9 @@ type UserQuery struct {
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -228,6 +232,28 @@ func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
return query
}
// QueryPromoCodeUsages chains the current query on the "promo_code_usages" edge.
func (_q *UserQuery) QueryPromoCodeUsages() *PromoCodeUsageQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PromoCodeUsagesTable, user.PromoCodeUsagesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
......@@ -449,6 +475,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
......@@ -533,6 +560,17 @@ func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery))
return _q
}
// WithPromoCodeUsages tells the query-builder to eager-load the nodes that are connected to
// the "promo_code_usages" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPromoCodeUsages(opts ...func(*PromoCodeUsageQuery)) *UserQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPromoCodeUsages = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
......@@ -622,7 +660,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [8]bool{
loadedTypes = [9]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
......@@ -630,6 +668,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withUserAllowedGroups != nil,
}
)
......@@ -642,6 +681,9 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -702,6 +744,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withPromoCodeUsages; query != nil {
if err := _q.loadPromoCodeUsages(ctx, query, nodes,
func(n *User) { n.Edges.PromoCodeUsages = []*PromoCodeUsage{} },
func(n *User, e *PromoCodeUsage) { n.Edges.PromoCodeUsages = append(n.Edges.PromoCodeUsages, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
......@@ -959,6 +1008,36 @@ func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttribu
}
return nil
}
func (_q *UserQuery) loadPromoCodeUsages(ctx context.Context, query *PromoCodeUsageQuery, nodes []*User, init func(*User), assign func(*User, *PromoCodeUsage)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(promocodeusage.FieldUserID)
}
query.Where(predicate.PromoCodeUsage(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PromoCodeUsagesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
......@@ -992,6 +1071,9 @@ func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllow
func (_q *UserQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -1054,6 +1136,9 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -1071,6 +1156,32 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserQuery) ForUpdate(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserQuery) ForShare(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserGroupBy is the group-by builder for User entities.
type UserGroupBy struct {
selector
......
......@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -291,6 +292,21 @@ func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
return _u.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdate) AddPromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
......@@ -443,6 +459,27 @@ func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.RemoveAttributeValueIDs(ids...)
}
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) ClearPromoCodeUsages() *UserUpdate {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdate) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdate) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
......@@ -893,6 +930,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
......@@ -1170,6 +1252,21 @@ func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdateOne) AddPromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
......@@ -1322,6 +1419,27 @@ func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUp
return _u.RemoveAttributeValueIDs(ids...)
}
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) ClearPromoCodeUsages() *UserUpdateOne {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdateOne) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdateOne) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
......@@ -1802,6 +1920,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
......
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/group"
......@@ -25,6 +26,7 @@ type UserAllowedGroupQuery struct {
predicates []predicate.UserAllowedGroup
withUser *UserQuery
withGroup *GroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -347,6 +349,9 @@ func (_q *UserAllowedGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -432,6 +437,9 @@ func (_q *UserAllowedGroupQuery) loadGroup(ctx context.Context, query *GroupQuer
func (_q *UserAllowedGroupQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Unique = false
_spec.Node.Columns = nil
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
......@@ -495,6 +503,9 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -512,6 +523,32 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAllowedGroupQuery) ForUpdate(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAllowedGroupQuery) ForShare(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAllowedGroupGroupBy is the group-by builder for UserAllowedGroup entities.
type UserAllowedGroupGroupBy struct {
selector
......
......@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -25,6 +26,7 @@ type UserAttributeDefinitionQuery struct {
inters []Interceptor
predicates []predicate.UserAttributeDefinition
withValues *UserAttributeValueQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -384,6 +386,9 @@ func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...que
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -436,6 +441,9 @@ func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *U
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -498,6 +506,9 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -515,6 +526,32 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeDefinitionQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeDefinitionQuery) ForShare(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
type UserAttributeDefinitionGroupBy struct {
selector
......
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct {
predicates []predicate.UserAttributeValue
withUser *UserQuery
withDefinition *UserAttributeDefinitionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeValueQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeValueQuery) ForShare(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type UserAttributeValueGroupBy struct {
selector
......
......@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct {
withGroup *GroupQuery
withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserSubscriptionQuery) ForUpdate(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserSubscriptionQuery) ForShare(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type UserSubscriptionGroupBy struct {
selector
......
......@@ -8,9 +8,11 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.17.2
github.com/shirou/gopsutil/v4 v4.25.6
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
......@@ -44,11 +46,13 @@ require (
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgraph-io/ristretto v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
......@@ -104,9 +108,9 @@ require (
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
......
......@@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
......@@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
......@@ -113,6 +117,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
......@@ -220,6 +226,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
......
......@@ -19,7 +19,9 @@ const (
RunModeSimple = "simple"
)
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
......@@ -43,12 +45,16 @@ type Config struct {
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
......@@ -57,14 +63,6 @@ type Config struct {
Update UpdateConfig `mapstructure:"update"`
}
// UpdateConfig 在线更新相关配置
type UpdateConfig struct {
// ProxyURL 用于访问 GitHub 的代理地址
// 支持 http/https/socks5/socks5h 协议
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
ProxyURL string `mapstructure:"proxy_url"`
}
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
......@@ -87,6 +85,33 @@ type GeminiTierQuotaConfig struct {
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
type UpdateConfig struct {
// ProxyURL 用于访问 GitHub 的代理地址
// 支持 http/https/socks5/socks5h 协议
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
ProxyURL string `mapstructure:"proxy_url"`
}
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
......@@ -209,6 +234,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
// 空闲超过此时间的会话将被自动释放
SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
......@@ -258,6 +287,29 @@ type GatewaySchedulingConfig struct {
// 过期槽位清理周期(0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
// 受控回源配置
DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"`
// 受控回源超时(秒),0 表示不额外收紧超时
DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"`
// 受控回源限流(实例级 QPS),0 表示不限制
DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"`
// Outbox 轮询与滞后阈值配置
// Outbox 轮询周期(秒)
OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"`
// Outbox 滞后告警阈值(秒)
OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"`
// Outbox 触发强制重建阈值(秒)
OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"`
// Outbox 连续滞后触发次数
OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"`
// Outbox 积压触发重建阈值(行数)
OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"`
// 全量重建周期配置
// 全量重建周期(秒),0 表示禁用
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
}
func (s *ServerConfig) Address() string {
......@@ -285,6 +337,13 @@ type DatabaseConfig struct {
}
func (d *DatabaseConfig) DSN() string {
// 当密码为空时不包含 password 参数,避免 libpq 解析错误
if d.Password == "" {
return fmt.Sprintf(
"host=%s port=%d user=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.DBName, d.SSLMode,
)
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
......@@ -296,6 +355,13 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
if tz == "" {
tz = "Asia/Shanghai"
}
// 当密码为空时不包含 password 参数,避免 libpq 解析错误
if d.Password == "" {
return fmt.Sprintf(
"host=%s port=%d user=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.DBName, d.SSLMode, tz,
)
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
......@@ -326,6 +392,47 @@ func (r *RedisConfig) Address() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
type OpsConfig struct {
// Enabled controls whether ops features should run.
//
// NOTE: vNext still has a DB-backed feature flag (ops_monitoring_enabled) for runtime on/off.
// This config flag is the "hard switch" for deployments that want to disable ops completely.
Enabled bool `mapstructure:"enabled"`
// UsePreaggregatedTables prefers ops_metrics_hourly/daily for long-window dashboard queries.
UsePreaggregatedTables bool `mapstructure:"use_preaggregated_tables"`
// Cleanup controls periodic deletion of old ops data to prevent unbounded growth.
Cleanup OpsCleanupConfig `mapstructure:"cleanup"`
// MetricsCollectorCache controls Redis caching for expensive per-window collector queries.
MetricsCollectorCache OpsMetricsCollectorCacheConfig `mapstructure:"metrics_collector_cache"`
// Pre-aggregation configuration.
Aggregation OpsAggregationConfig `mapstructure:"aggregation"`
}
type OpsCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
// Retention days (0 disables that cleanup target).
//
// vNext requirement: default 30 days across ops datasets.
ErrorLogRetentionDays int `mapstructure:"error_log_retention_days"`
MinuteMetricsRetentionDays int `mapstructure:"minute_metrics_retention_days"`
HourlyMetricsRetentionDays int `mapstructure:"hourly_metrics_retention_days"`
}
type OpsAggregationConfig struct {
Enabled bool `mapstructure:"enabled"`
}
type OpsMetricsCollectorCacheConfig struct {
Enabled bool `mapstructure:"enabled"`
TTL time.Duration `mapstructure:"ttl"`
}
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
......@@ -335,30 +442,6 @@ type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
//
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
// 这里是用于登录 Sub2API 本身的用户体系。
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
......@@ -372,6 +455,55 @@ type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
type APIKeyAuthCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
Singleflight bool `mapstructure:"singleflight"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存
Enabled bool `mapstructure:"enabled"`
// KeyPrefix: Redis key 前缀,用于多环境隔离
KeyPrefix string `mapstructure:"key_prefix"`
// StatsFreshTTLSeconds: 缓存命中认为“新鲜”的时间窗口(秒)
StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"`
// StatsTTLSeconds: Redis 缓存总 TTL(秒)
StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"`
// StatsRefreshTimeoutSeconds: 异步刷新超时(秒)
StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"`
}
// DashboardAggregationConfig 仪表盘预聚合配置
type DashboardAggregationConfig struct {
// Enabled: 是否启用预聚合作业
Enabled bool `mapstructure:"enabled"`
// IntervalSeconds: 聚合刷新间隔(秒)
IntervalSeconds int `mapstructure:"interval_seconds"`
// LookbackSeconds: 回看窗口(秒)
LookbackSeconds int `mapstructure:"lookback_seconds"`
// BackfillEnabled: 是否允许全量回填
BackfillEnabled bool `mapstructure:"backfill_enabled"`
// BackfillMaxDays: 回填最大跨度(天)
BackfillMaxDays int `mapstructure:"backfill_max_days"`
// Retention: 各表保留窗口(天)
Retention DashboardAggregationRetentionConfig `mapstructure:"retention"`
// RecomputeDays: 启动时重算最近 N 天
RecomputeDays int `mapstructure:"recompute_days"`
}
// DashboardAggregationRetentionConfig 预聚合保留窗口
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
......@@ -437,6 +569,7 @@ func Load() (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
......@@ -475,81 +608,6 @@ func Load() (*Config, error) {
return &cfg, nil
}
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// ValidateFrontendRedirectURL 校验前端回调地址:
// - 允许同源相对路径(以 / 开头)
// - 或绝对 http(s) URL(禁止 fragment)
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
if strings.ContainsAny(raw, "\r\n") {
return fmt.Errorf("contains invalid characters")
}
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") {
return fmt.Errorf("must not start with //")
}
return nil
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute http(s) url or relative path")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
}
}
func setDefaults() {
viper.SetDefault("run_mode", RunModeStandard)
......@@ -599,7 +657,7 @@ func setDefaults() {
// Turnstile
viper.SetDefault("turnstile.required", false)
// LinuxDo Connect OAuth 登录(终端用户 SSO)
// LinuxDo Connect OAuth 登录
viper.SetDefault("linuxdo_connect.enabled", false)
viper.SetDefault("linuxdo_connect.client_id", "")
viper.SetDefault("linuxdo_connect.client_secret", "")
......@@ -638,6 +696,20 @@ func setDefaults() {
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
viper.SetDefault("ops.use_preaggregated_tables", false)
viper.SetDefault("ops.cleanup.enabled", true)
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
// Retention days: vNext defaults to 30 days across ops datasets.
viper.SetDefault("ops.cleanup.error_log_retention_days", 30)
viper.SetDefault("ops.cleanup.minute_metrics_retention_days", 30)
viper.SetDefault("ops.cleanup.hourly_metrics_retention_days", 30)
viper.SetDefault("ops.aggregation.enabled", true)
viper.SetDefault("ops.metrics_collector_cache.enabled", true)
// TTL should be slightly larger than collection interval (1m) to maximize cross-replica cache hits.
viper.SetDefault("ops.metrics_collector_cache.ttl", 65*time.Second)
// JWT
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
......@@ -666,9 +738,35 @@ func setDefaults() {
// Timezone (default to Asia/Shanghai for Chinese users)
viper.SetDefault("timezone", "Asia/Shanghai")
// API Key auth cache
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15)
viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30)
viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30)
// Dashboard aggregation
viper.SetDefault("dashboard_aggregation.enabled", true)
viper.SetDefault("dashboard_aggregation.interval_seconds", 60)
viper.SetDefault("dashboard_aggregation.lookback_seconds", 120)
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body", true)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
......@@ -689,12 +787,21 @@ func setDefaults() {
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0)
viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1)
viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5)
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10)
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
......@@ -711,10 +818,6 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Update - 在线更新配置
// 代理地址为空表示直连 GitHub(适用于海外服务器)
viper.SetDefault("update.proxy_url", "")
}
func (c *Config) Validate() error {
......@@ -755,7 +858,8 @@ func (c *Config) Validate() error {
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
......@@ -828,6 +932,78 @@ func (c *Config) Validate() error {
if c.Redis.MinIdleConns > c.Redis.PoolSize {
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
}
if c.Dashboard.Enabled {
if c.Dashboard.StatsFreshTTLSeconds <= 0 {
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive")
}
if c.Dashboard.StatsTTLSeconds <= 0 {
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive")
}
if c.Dashboard.StatsRefreshTimeoutSeconds <= 0 {
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive")
}
if c.Dashboard.StatsFreshTTLSeconds > c.Dashboard.StatsTTLSeconds {
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be <= dashboard_cache.stats_ttl_seconds")
}
} else {
if c.Dashboard.StatsFreshTTLSeconds < 0 {
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be non-negative")
}
if c.Dashboard.StatsTTLSeconds < 0 {
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be non-negative")
}
if c.Dashboard.StatsRefreshTimeoutSeconds < 0 {
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative")
}
}
if c.DashboardAgg.Enabled {
if c.DashboardAgg.IntervalSeconds <= 0 {
return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive")
}
if c.DashboardAgg.LookbackSeconds < 0 {
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
}
if c.DashboardAgg.BackfillMaxDays < 0 {
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
}
if c.DashboardAgg.BackfillEnabled && c.DashboardAgg.BackfillMaxDays == 0 {
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be positive")
}
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
}
if c.DashboardAgg.Retention.HourlyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
}
if c.DashboardAgg.Retention.DailyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be positive")
}
if c.DashboardAgg.RecomputeDays < 0 {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
}
} else {
if c.DashboardAgg.IntervalSeconds < 0 {
return fmt.Errorf("dashboard_aggregation.interval_seconds must be non-negative")
}
if c.DashboardAgg.LookbackSeconds < 0 {
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
}
if c.DashboardAgg.BackfillMaxDays < 0 {
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
}
if c.DashboardAgg.Retention.HourlyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
}
if c.DashboardAgg.Retention.DailyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be non-negative")
}
if c.DashboardAgg.RecomputeDays < 0 {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
}
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
......@@ -898,6 +1074,50 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
if c.Gateway.Scheduling.DbFallbackTimeoutSeconds < 0 {
return fmt.Errorf("gateway.scheduling.db_fallback_timeout_seconds must be non-negative")
}
if c.Gateway.Scheduling.DbFallbackMaxQPS < 0 {
return fmt.Errorf("gateway.scheduling.db_fallback_max_qps must be non-negative")
}
if c.Gateway.Scheduling.OutboxPollIntervalSeconds <= 0 {
return fmt.Errorf("gateway.scheduling.outbox_poll_interval_seconds must be positive")
}
if c.Gateway.Scheduling.OutboxLagWarnSeconds < 0 {
return fmt.Errorf("gateway.scheduling.outbox_lag_warn_seconds must be non-negative")
}
if c.Gateway.Scheduling.OutboxLagRebuildSeconds < 0 {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be non-negative")
}
if c.Gateway.Scheduling.OutboxLagRebuildFailures <= 0 {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_failures must be positive")
}
if c.Gateway.Scheduling.OutboxBacklogRebuildRows < 0 {
return fmt.Errorf("gateway.scheduling.outbox_backlog_rebuild_rows must be non-negative")
}
if c.Gateway.Scheduling.FullRebuildIntervalSeconds < 0 {
return fmt.Errorf("gateway.scheduling.full_rebuild_interval_seconds must be non-negative")
}
if c.Gateway.Scheduling.OutboxLagWarnSeconds > 0 &&
c.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 &&
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
}
if c.Ops.MetricsCollectorCache.TTL < 0 {
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
}
if c.Ops.Cleanup.ErrorLogRetentionDays < 0 {
return fmt.Errorf("ops.cleanup.error_log_retention_days must be non-negative")
}
if c.Ops.Cleanup.MinuteMetricsRetentionDays < 0 {
return fmt.Errorf("ops.cleanup.minute_metrics_retention_days must be non-negative")
}
if c.Ops.Cleanup.HourlyMetricsRetentionDays < 0 {
return fmt.Errorf("ops.cleanup.hourly_metrics_retention_days must be non-negative")
}
if c.Ops.Cleanup.Enabled && strings.TrimSpace(c.Ops.Cleanup.Schedule) == "" {
return fmt.Errorf("ops.cleanup.schedule is required when ops.cleanup.enabled=true")
}
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
......@@ -974,3 +1194,77 @@ func GetServerAddress() string {
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}
// ValidateAbsoluteHTTPURL 验证是否为有效的绝对 HTTP(S) URL
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// ValidateFrontendRedirectURL 验证前端重定向 URL(可以是绝对 URL 或相对路径)
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
if strings.ContainsAny(raw, "\r\n") {
return fmt.Errorf("contains invalid characters")
}
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") {
return fmt.Errorf("must not start with //")
}
return nil
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute http(s) url or relative path")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment