Commit bb664d9b authored by yangjianbo's avatar yangjianbo
Browse files

feat(sync): full code sync from release

parent bfc7b339
......@@ -72,6 +72,12 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
// Sora 存储配额
field.Int64("sora_storage_quota_bytes").
Default(0),
field.Int64("sora_storage_used_bytes").
Default(0),
}
}
......
......@@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index {
index.Fields("group_id"),
index.Fields("status"),
index.Fields("expires_at"),
// 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐)
index.Fields("user_id", "status", "expires_at"),
index.Fields("assigned_by"),
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
......
......@@ -28,6 +28,8 @@ type Tx struct {
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
......@@ -192,6 +194,7 @@ func (tx *Tx) init() {
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
......
......@@ -45,6 +45,10 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
......@@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
......@@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
case user.FieldSoraStorageQuotaBytes:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
} else if value.Valid {
_m.SoraStorageQuotaBytes = value.Int64
}
case user.FieldSoraStorageUsedBytes:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
} else if value.Valid {
_m.SoraStorageUsedBytes = value.Int64
}
default:
_m.selectValues.Set(columns[i], values[i])
}
......@@ -424,6 +440,12 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("sora_storage_quota_bytes=")
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
builder.WriteString(", ")
builder.WriteString("sora_storage_used_bytes=")
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
builder.WriteByte(')')
return builder.String()
}
......
......@@ -43,6 +43,10 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
// FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
......@@ -152,6 +156,8 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
FieldSoraStorageQuotaBytes,
FieldSoraStorageUsedBytes,
}
var (
......@@ -208,6 +214,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
DefaultSoraStorageQuotaBytes int64
// DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
DefaultSoraStorageUsedBytes int64
)
// OrderOption defines the ordering options for the User queries.
......@@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
}
// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
......
......@@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
func SoraStorageQuotaBytes(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
func SoraStorageUsedBytes(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
......@@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesEQ(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
}
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
}
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesGT(v int64) predicate.User {
return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesGTE(v int64) predicate.User {
return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesLT(v int64) predicate.User {
return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesLTE(v int64) predicate.User {
return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesEQ(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesNEQ(v int64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
}
// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
}
// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesGT(v int64) predicate.User {
return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesGTE(v int64) predicate.User {
return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesLT(v int64) predicate.User {
return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesLTE(v int64) predicate.User {
return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
......
......@@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
_c.mutation.SetSoraStorageQuotaBytes(v)
return _c
}
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
if v != nil {
_c.SetSoraStorageQuotaBytes(*v)
}
return _c
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
_c.mutation.SetSoraStorageUsedBytes(v)
return _c
}
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
if v != nil {
_c.SetSoraStorageUsedBytes(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
......@@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
v := user.DefaultSoraStorageQuotaBytes
_c.mutation.SetSoraStorageQuotaBytes(v)
}
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
v := user.DefaultSoraStorageUsedBytes
_c.mutation.SetSoraStorageUsedBytes(v)
}
return nil
}
......@@ -487,6 +523,12 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
}
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
}
return nil
}
......@@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
_node.SoraStorageQuotaBytes = value
}
if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
_node.SoraStorageUsedBytes = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
u.Set(user.FieldSoraStorageQuotaBytes, v)
return u
}
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
u.SetExcluded(user.FieldSoraStorageQuotaBytes)
return u
}
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
u.Add(user.FieldSoraStorageQuotaBytes, v)
return u
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
u.Set(user.FieldSoraStorageUsedBytes, v)
return u
}
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
u.SetExcluded(user.FieldSoraStorageUsedBytes)
return u
}
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
u.Add(user.FieldSoraStorageUsedBytes, v)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
......@@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageQuotaBytes(v)
})
}
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageQuotaBytes(v)
})
}
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageQuotaBytes()
})
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageUsedBytes(v)
})
}
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageUsedBytes(v)
})
}
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageUsedBytes()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
......@@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageQuotaBytes(v)
})
}
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageQuotaBytes(v)
})
}
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageQuotaBytes()
})
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageUsedBytes(v)
})
}
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageUsedBytes(v)
})
}
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageUsedBytes()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
......
......@@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
_u.mutation.ResetSoraStorageQuotaBytes()
_u.mutation.SetSoraStorageQuotaBytes(v)
return _u
}
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
if v != nil {
_u.SetSoraStorageQuotaBytes(*v)
}
return _u
}
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
_u.mutation.AddSoraStorageQuotaBytes(v)
return _u
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
_u.mutation.ResetSoraStorageUsedBytes()
_u.mutation.SetSoraStorageUsedBytes(v)
return _u
}
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
if v != nil {
_u.SetSoraStorageUsedBytes(*v)
}
return _u
}
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
_u.mutation.AddSoraStorageUsedBytes(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
_u.mutation.ResetSoraStorageQuotaBytes()
_u.mutation.SetSoraStorageQuotaBytes(v)
return _u
}
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
if v != nil {
_u.SetSoraStorageQuotaBytes(*v)
}
return _u
}
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
_u.mutation.AddSoraStorageQuotaBytes(v)
return _u
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
_u.mutation.ResetSoraStorageUsedBytes()
_u.mutation.SetSoraStorageUsedBytes(v)
return _u
}
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
if v != nil {
_u.SetSoraStorageUsedBytes(*v)
}
return _u
}
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
_u.mutation.AddSoraStorageUsedBytes(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......
......@@ -7,7 +7,11 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/coder/websocket v1.8.14
github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
......@@ -34,6 +38,8 @@ require (
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.40.0
google.golang.org/grpc v1.75.1
google.golang.org/protobuf v1.36.10
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
......@@ -47,6 +53,22 @@ require (
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.1 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
......@@ -146,7 +168,6 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
......@@ -156,8 +177,7 @@ require (
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
......
......@@ -22,6 +22,44 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw=
github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY=
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc=
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
......@@ -56,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
......@@ -127,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
......@@ -190,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
......@@ -223,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
......@@ -274,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
......@@ -344,6 +396,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
......@@ -399,6 +453,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
......
......@@ -364,6 +364,8 @@ type GatewayConfig struct {
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
......@@ -450,6 +452,101 @@ type GatewayConfig struct {
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct {
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true)
Enabled bool `mapstructure:"enabled"`
// OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS
OAuthEnabled bool `mapstructure:"oauth_enabled"`
// APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS
APIKeyEnabled bool `mapstructure:"apikey_enabled"`
// ForceHTTP: 全局强制 HTTP(用于紧急回滚)
ForceHTTP bool `mapstructure:"force_http"`
// AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false)
AllowStoreRecovery bool `mapstructure:"allow_store_recovery"`
// IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true)
IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"`
// StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off)
// - strict: 强制新建连接(隔离优先)
// - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中)
// - off: 不强制新建连接(复用优先)
StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"`
// StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离)
// 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。
StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"`
// PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false)
PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"`
// Feature 开关:v2 优先于 v1
ResponsesWebsockets bool `mapstructure:"responses_websockets"`
ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"`
// 连接池参数
MaxConnsPerAccount int `mapstructure:"max_conns_per_account"`
MinIdlePerAccount int `mapstructure:"min_idle_per_account"`
MaxIdlePerAccount int `mapstructure:"max_idle_per_account"`
// DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限
DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"`
// OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor))
OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
// APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor))
APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
// EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数)
EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
// EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发
EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"`
// PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒)
PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"`
// FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却
FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"`
// RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避
RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"`
// RetryBackoffMaxMS: WS 重试最大退避(毫秒)
RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"`
// RetryJitterRatio: WS 重试退避抖动比例(0-1)
RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"`
// RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制
RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"`
// PayloadLogSampleRate: payload_schema 日志采样率(0-1)
PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
// 账号调度与粘连参数
LBTopK int `mapstructure:"lb_top_k"`
// StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL
StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"`
// SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key”
SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"`
// SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL)
SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"`
// MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接
MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"`
// StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL
StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"`
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
}
// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。
type GatewayOpenAIWSSchedulerScoreWeights struct {
Priority float64 `mapstructure:"priority"`
Load float64 `mapstructure:"load"`
Queue float64 `mapstructure:"queue"`
ErrorRate float64 `mapstructure:"error_rate"`
TTFT float64 `mapstructure:"ttft"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
......@@ -886,6 +983,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0)时回退旧键;新键优先。
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
......@@ -945,7 +1048,7 @@ func setDefaults() {
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
viper.SetDefault("server.max_request_body_size", int64(256*1024*1024))
// H2C 默认配置
viper.SetDefault("server.h2c.enabled", false)
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
......@@ -1088,9 +1191,9 @@ func setDefaults() {
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置
viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
......@@ -1157,9 +1260,55 @@ func setDefaults() {
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
viper.SetDefault("gateway.openai_ws.allow_store_recovery", false)
viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true)
viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict")
viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true)
viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false)
viper.SetDefault("gateway.openai_ws.responses_websockets", false)
viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true)
viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128)
viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4)
viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12)
viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true)
viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10)
viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900)
viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120)
viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7)
viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64)
viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1)
viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10)
viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300)
viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30)
viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120)
viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000)
viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true)
viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
......@@ -1747,6 +1896,118 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
}
if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 {
return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive")
}
if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 {
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative")
}
if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 {
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative")
}
if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount {
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account")
}
if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount {
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account")
}
if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 {
return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive")
}
if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 {
return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive")
}
if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 {
return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]")
}
if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 {
return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive")
}
if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 {
return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive")
}
if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 {
return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative")
}
if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 {
return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative")
}
if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 {
return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 &&
c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS {
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms")
}
if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 {
return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]")
}
if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
case "off", "shared", "dedicated":
default:
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
switch mode {
case "strict", "adaptive", "off":
default:
return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off")
}
}
if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
}
if c.Gateway.OpenAIWS.LBTopK <= 0 {
return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
}
if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive")
}
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive")
}
if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
}
if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
}
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load +
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue +
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate +
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT
if weightSum <= 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
}
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
......
......@@ -6,6 +6,7 @@ import (
"time"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
)
func resetViperWithJWTSecret(t *testing.T) {
......@@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
}
}
func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Gateway.OpenAIWS.Enabled {
t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true")
}
if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true")
}
if cfg.Gateway.OpenAIWS.ResponsesWebsockets {
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false")
}
if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled {
t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true")
}
if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 {
t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor)
}
if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 {
t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor)
}
if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 {
t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds)
}
if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback {
t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true")
}
if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld {
t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true")
}
if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled {
t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true")
}
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 {
t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
}
if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 {
t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds)
}
if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 {
t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize)
}
if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 {
t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS)
}
if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 {
t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS)
}
if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 {
t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS)
}
if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 {
t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS)
}
if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 {
t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio)
}
if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 {
t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS)
}
if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 {
t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate)
}
if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true")
}
if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
}
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
}
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
}
}
func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 {
t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
}
}
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
resetViperWithJWTSecret(t)
......@@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
{
name: "gateway openai ws oauth max conns factor",
mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 },
wantErr: "gateway.openai_ws.oauth_max_conns_factor",
},
{
name: "gateway openai ws apikey max conns factor",
mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
wantErr: "gateway.openai_ws.apikey_max_conns_factor",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
......@@ -1174,6 +1282,165 @@ func TestValidateConfigErrors(t *testing.T) {
}
}
func TestValidateConfig_OpenAIWSRules(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
return cfg
}
t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) {
cfg := buildValid(t)
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200
require.NoError(t, cfg.Validate())
require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
})
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "max_conns_per_account 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 },
wantErr: "gateway.openai_ws.max_conns_per_account",
},
{
name: "min_idle_per_account 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 },
wantErr: "gateway.openai_ws.min_idle_per_account",
},
{
name: "max_idle_per_account 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 },
wantErr: "gateway.openai_ws.max_idle_per_account",
},
{
name: "min_idle_per_account 不能大于 max_idle_per_account",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.MinIdlePerAccount = 3
c.Gateway.OpenAIWS.MaxIdlePerAccount = 2
},
wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account",
},
{
name: "max_idle_per_account 不能大于 max_conns_per_account",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.MaxConnsPerAccount = 2
c.Gateway.OpenAIWS.MinIdlePerAccount = 1
c.Gateway.OpenAIWS.MaxIdlePerAccount = 3
},
wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account",
},
{
name: "dial_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.dial_timeout_seconds",
},
{
name: "read_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.read_timeout_seconds",
},
{
name: "write_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.write_timeout_seconds",
},
{
name: "pool_target_utilization 必须在 (0,1]",
mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 },
wantErr: "gateway.openai_ws.pool_target_utilization",
},
{
name: "queue_limit_per_conn 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 },
wantErr: "gateway.openai_ws.queue_limit_per_conn",
},
{
name: "fallback_cooldown_seconds 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 },
wantErr: "gateway.openai_ws.fallback_cooldown_seconds",
},
{
name: "store_disabled_conn_mode 必须为 strict|adaptive|off",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" },
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
name: "ingress_mode_default 必须为 off|shared|dedicated",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},
{
name: "payload_log_sample_rate 必须在 [0,1] 范围内",
mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 },
wantErr: "gateway.openai_ws.payload_log_sample_rate",
},
{
name: "retry_total_budget_ms 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
wantErr: "gateway.openai_ws.retry_total_budget_ms",
},
{
name: "lb_top_k 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
wantErr: "gateway.openai_ws.lb_top_k",
},
{
name: "sticky_session_ttl_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 },
wantErr: "gateway.openai_ws.sticky_session_ttl_seconds",
},
{
name: "sticky_response_id_ttl_seconds 必须为正数",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0
},
wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds",
},
{
name: "sticky_previous_response_ttl_seconds 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 },
wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds",
},
{
name: "scheduler_score_weights 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 },
wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative",
},
{
name: "scheduler_score_weights 不能全为 0",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0
},
wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero",
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
cfg := buildValid(t)
tc.mutate(cfg)
err := cfg.Validate()
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantErr)
})
}
}
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
......
......@@ -104,6 +104,9 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
// Gemini 3.1 image preview 映射
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
// Gemini 3 image 兼容映射(向 3.1 image 迁移)
"gemini-3-pro-image": "gemini-3.1-flash-image",
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
......
package domain
import "testing"
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
t.Parallel()
cases := map[string]string{
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
"gemini-3-pro-image": "gemini-3.1-flash-image",
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
}
for from, want := range cases {
got, ok := DefaultAntigravityModelMapping[from]
if !ok {
t.Fatalf("expected mapping for %q to exist", from)
}
if got != want {
t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
}
}
}
......@@ -1337,6 +1337,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
response.Success(c, stats)
}
// BatchTodayStatsRequest 批量今日统计请求体。
type BatchTodayStatsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required"`
}
// GetBatchTodayStats 批量获取多个账号的今日统计。
// POST /api/v1/admin/accounts/today-stats/batch
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
var req BatchTodayStatsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.AccountIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"stats": stats})
}
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
......
......@@ -3,6 +3,7 @@ package admin
import (
"errors"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
......@@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
......@@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
var requestType *int16
var stream *bool
var billingType *int8
......@@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
......@@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
......@@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
......@@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
......@@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
......
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCapture struct {
service.UsageLogRepository
trendRequestType *int16
trendStream *bool
modelRequestType *int16
modelStream *bool
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
s.trendRequestType = requestType
s.trendStream = stream
return []usagestats.TrendDataPoint{}, nil
}
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, error) {
s.modelRequestType = requestType
s.modelStream = stream
return []usagestats.ModelStat{}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
return router
}
func TestDashboardTrendRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.trendRequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
require.Nil(t, repo.trendStream)
}
func TestDashboardTrendInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardTrendInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.modelRequestType)
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
require.Nil(t, repo.modelStream)
}
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
package admin
import (
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type DataManagementHandler struct {
dataManagementService *service.DataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
Bucket string `json:"bucket" binding:"required"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type CreateBackupJobRequest struct {
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id"`
PostgresID string `json:"postgres_profile_id"`
RedisID string `json:"redis_profile_id"`
IdempotencyKey string `json:"idempotency_key"`
}
type CreateSourceProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
SetActive bool `json:"set_active"`
}
type UpdateSourceProfileRequest struct {
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
}
type CreateS3ProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
SetActive bool `json:"set_active"`
}
type UpdateS3ProfileRequest struct {
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
health := h.getAgentHealth(c)
payload := gin.H{
"enabled": health.Enabled,
"reason": health.Reason,
"socket_path": health.SocketPath,
}
if health.Agent != nil {
payload["agent"] = gin.H{
"status": health.Agent.Status,
"version": health.Agent.Version,
"uptime_seconds": health.Agent.UptimeSeconds,
}
}
response.Success(c, payload)
}
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
var req service.DataManagementConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) TestS3(c *gin.Context) {
var req TestS3ConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
}
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
var req CreateBackupJobRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
if !h.requireAgentEnabled(c) {
return
}
triggeredBy := "admin:unknown"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
BackupType: req.BackupType,
UploadToS3: req.UploadToS3,
S3ProfileID: req.S3ProfileID,
PostgresID: req.PostgresID,
RedisID: req.RedisID,
TriggeredBy: triggeredBy,
IdempotencyKey: req.IdempotencyKey,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
}
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType == "" {
response.BadRequest(c, "Invalid source_type")
return
}
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
var req CreateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
SourceType: sourceType,
ProfileID: req.ProfileID,
Name: req.Name,
Config: req.Config,
SetActive: req.SetActive,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
var req UpdateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
SourceType: sourceType,
ProfileID: profileID,
Name: req.Name,
Config: req.Config,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
var req CreateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
ProfileID: req.ProfileID,
Name: req.Name,
SetActive: req.SetActive,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
var req UpdateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
ProfileID: profileID,
Name: req.Name,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
pageSize := int32(20)
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
response.BadRequest(c, "Invalid page_size")
return
}
pageSize = int32(v)
}
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
PageSize: pageSize,
PageToken: c.Query("page_token"),
Status: c.Query("status"),
BackupType: c.Query("backup_type"),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
jobID := strings.TrimSpace(c.Param("job_id"))
if jobID == "" {
response.BadRequest(c, "Invalid backup job ID")
return
}
if !h.requireAgentEnabled(c) {
return
}
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, job)
}
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
if h.dataManagementService == nil {
err := infraerrors.ServiceUnavailable(
service.DataManagementAgentUnavailableReason,
"data management agent service is not configured",
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
response.ErrorFrom(c, err)
return false
}
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return false
}
return true
}
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
if h.dataManagementService == nil {
return service.DataManagementAgentHealth{
Enabled: false,
Reason: service.DataManagementAgentUnavailableReason,
SocketPath: service.DefaultDataManagementAgentSocketPath,
}
}
return h.dataManagementService.GetAgentHealth(c.Request.Context())
}
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
headerKey := strings.TrimSpace(headerValue)
if headerKey != "" {
return headerKey
}
return strings.TrimSpace(bodyValue)
}
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type apiEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason"`
Data json.RawMessage `json:"data"`
}
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, 0, envelope.Code)
var data struct {
Enabled bool `json:"enabled"`
Reason string `json:"reason"`
SocketPath string `json:"socket_path"`
}
require.NoError(t, json.Unmarshal(envelope.Data, &data))
require.False(t, data.Enabled)
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
require.Equal(t, svc.SocketPath(), data.SocketPath)
}
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
}
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
}
......@@ -51,6 +51,8 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -84,6 +86,8 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......@@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment