Unverified Commit 9d795061 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #682 from mt21625457/pr/all-code-sync-20260228

feat(openai-ws): support websocket mode v2, optimize relay performance, enhance sora
parents bfc7b339 1d1fc019
......@@ -72,6 +72,12 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
// Sora 存储配额
field.Int64("sora_storage_quota_bytes").
Default(0),
field.Int64("sora_storage_used_bytes").
Default(0),
}
}
......
......@@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index {
index.Fields("group_id"),
index.Fields("status"),
index.Fields("expires_at"),
// 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐)
index.Fields("user_id", "status", "expires_at"),
index.Fields("assigned_by"),
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
......
......@@ -28,6 +28,8 @@ type Tx struct {
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
......@@ -192,6 +194,7 @@ func (tx *Tx) init() {
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
......
......@@ -45,6 +45,10 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
......@@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
......@@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
case user.FieldSoraStorageQuotaBytes:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
} else if value.Valid {
_m.SoraStorageQuotaBytes = value.Int64
}
case user.FieldSoraStorageUsedBytes:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
} else if value.Valid {
_m.SoraStorageUsedBytes = value.Int64
}
default:
_m.selectValues.Set(columns[i], values[i])
}
......@@ -424,6 +440,12 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("sora_storage_quota_bytes=")
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
builder.WriteString(", ")
builder.WriteString("sora_storage_used_bytes=")
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
builder.WriteByte(')')
return builder.String()
}
......
......@@ -43,6 +43,10 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
// FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
......@@ -152,6 +156,8 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
FieldSoraStorageQuotaBytes,
FieldSoraStorageUsedBytes,
}
var (
......@@ -208,6 +214,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
DefaultSoraStorageQuotaBytes int64
// DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
DefaultSoraStorageUsedBytes int64
)
// OrderOption defines the ordering options for the User queries.
......@@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
}
// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
......
......@@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
func SoraStorageQuotaBytes(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
func SoraStorageUsedBytes(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
......@@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesEQ(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
}
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
}
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesGT(v int64) predicate.User {
return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesGTE(v int64) predicate.User {
return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesLT(v int64) predicate.User {
return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
func SoraStorageQuotaBytesLTE(v int64) predicate.User {
return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
}
// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesEQ(v int64) predicate.User {
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesNEQ(v int64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
}
// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
}
// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesGT(v int64) predicate.User {
return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesGTE(v int64) predicate.User {
return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesLT(v int64) predicate.User {
return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
}
// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
func SoraStorageUsedBytesLTE(v int64) predicate.User {
return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
......
......@@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
_c.mutation.SetSoraStorageQuotaBytes(v)
return _c
}
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
if v != nil {
_c.SetSoraStorageQuotaBytes(*v)
}
return _c
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
_c.mutation.SetSoraStorageUsedBytes(v)
return _c
}
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
if v != nil {
_c.SetSoraStorageUsedBytes(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
......@@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
v := user.DefaultSoraStorageQuotaBytes
_c.mutation.SetSoraStorageQuotaBytes(v)
}
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
v := user.DefaultSoraStorageUsedBytes
_c.mutation.SetSoraStorageUsedBytes(v)
}
return nil
}
......@@ -487,6 +523,12 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
}
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
}
return nil
}
......@@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
_node.SoraStorageQuotaBytes = value
}
if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
_node.SoraStorageUsedBytes = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
u.Set(user.FieldSoraStorageQuotaBytes, v)
return u
}
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
u.SetExcluded(user.FieldSoraStorageQuotaBytes)
return u
}
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
u.Add(user.FieldSoraStorageQuotaBytes, v)
return u
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
u.Set(user.FieldSoraStorageUsedBytes, v)
return u
}
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
u.SetExcluded(user.FieldSoraStorageUsedBytes)
return u
}
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
u.Add(user.FieldSoraStorageUsedBytes, v)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
......@@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageQuotaBytes(v)
})
}
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageQuotaBytes(v)
})
}
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageQuotaBytes()
})
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageUsedBytes(v)
})
}
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageUsedBytes(v)
})
}
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageUsedBytes()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
......@@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageQuotaBytes(v)
})
}
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageQuotaBytes(v)
})
}
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageQuotaBytes()
})
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetSoraStorageUsedBytes(v)
})
}
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddSoraStorageUsedBytes(v)
})
}
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateSoraStorageUsedBytes()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
......
......@@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
_u.mutation.ResetSoraStorageQuotaBytes()
_u.mutation.SetSoraStorageQuotaBytes(v)
return _u
}
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
if v != nil {
_u.SetSoraStorageQuotaBytes(*v)
}
return _u
}
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
_u.mutation.AddSoraStorageQuotaBytes(v)
return _u
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
_u.mutation.ResetSoraStorageUsedBytes()
_u.mutation.SetSoraStorageUsedBytes(v)
return _u
}
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
if v != nil {
_u.SetSoraStorageUsedBytes(*v)
}
return _u
}
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
_u.mutation.AddSoraStorageUsedBytes(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
_u.mutation.ResetSoraStorageQuotaBytes()
_u.mutation.SetSoraStorageQuotaBytes(v)
return _u
}
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
if v != nil {
_u.SetSoraStorageQuotaBytes(*v)
}
return _u
}
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
_u.mutation.AddSoraStorageQuotaBytes(v)
return _u
}
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
_u.mutation.ResetSoraStorageUsedBytes()
_u.mutation.SetSoraStorageUsedBytes(v)
return _u
}
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
if v != nil {
_u.SetSoraStorageUsedBytes(*v)
}
return _u
}
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
_u.mutation.AddSoraStorageUsedBytes(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......
......@@ -7,7 +7,11 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/coder/websocket v1.8.14
github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
......@@ -34,6 +38,8 @@ require (
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.40.0
google.golang.org/grpc v1.75.1
google.golang.org/protobuf v1.36.10
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
......@@ -47,6 +53,22 @@ require (
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.1 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
......@@ -146,7 +168,6 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
......@@ -156,8 +177,7 @@ require (
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
......
......@@ -22,6 +22,44 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw=
github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY=
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc=
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
......@@ -56,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
......@@ -127,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
......@@ -190,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
......@@ -223,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
......@@ -274,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
......@@ -344,6 +396,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
......@@ -399,6 +453,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
......
......@@ -364,6 +364,8 @@ type GatewayConfig struct {
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
......@@ -450,6 +452,101 @@ type GatewayConfig struct {
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct {
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true)
Enabled bool `mapstructure:"enabled"`
// OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS
OAuthEnabled bool `mapstructure:"oauth_enabled"`
// APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS
APIKeyEnabled bool `mapstructure:"apikey_enabled"`
// ForceHTTP: 全局强制 HTTP(用于紧急回滚)
ForceHTTP bool `mapstructure:"force_http"`
// AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false)
AllowStoreRecovery bool `mapstructure:"allow_store_recovery"`
// IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true)
IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"`
// StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off)
// - strict: 强制新建连接(隔离优先)
// - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中)
// - off: 不强制新建连接(复用优先)
StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"`
// StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离)
// 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。
StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"`
// PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false)
PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"`
// Feature 开关:v2 优先于 v1
ResponsesWebsockets bool `mapstructure:"responses_websockets"`
ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"`
// 连接池参数
MaxConnsPerAccount int `mapstructure:"max_conns_per_account"`
MinIdlePerAccount int `mapstructure:"min_idle_per_account"`
MaxIdlePerAccount int `mapstructure:"max_idle_per_account"`
// DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限
DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"`
// OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor))
OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
// APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor))
APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
// EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数)
EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
// EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发
EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"`
// PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒)
PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"`
// FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却
FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"`
// RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避
RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"`
// RetryBackoffMaxMS: WS 重试最大退避(毫秒)
RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"`
// RetryJitterRatio: WS 重试退避抖动比例(0-1)
RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"`
// RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制
RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"`
// PayloadLogSampleRate: payload_schema 日志采样率(0-1)
PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
// 账号调度与粘连参数
LBTopK int `mapstructure:"lb_top_k"`
// StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL
StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"`
// SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key”
SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"`
// SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL)
SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"`
// MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接
MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"`
// StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL
StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"`
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
}
// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。
type GatewayOpenAIWSSchedulerScoreWeights struct {
Priority float64 `mapstructure:"priority"`
Load float64 `mapstructure:"load"`
Queue float64 `mapstructure:"queue"`
ErrorRate float64 `mapstructure:"error_rate"`
TTFT float64 `mapstructure:"ttft"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
......@@ -886,6 +983,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0)时回退旧键;新键优先。
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
......@@ -945,7 +1048,7 @@ func setDefaults() {
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
viper.SetDefault("server.max_request_body_size", int64(256*1024*1024))
// H2C 默认配置
viper.SetDefault("server.h2c.enabled", false)
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
......@@ -1088,9 +1191,9 @@ func setDefaults() {
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置
viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
......@@ -1157,9 +1260,55 @@ func setDefaults() {
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
viper.SetDefault("gateway.openai_ws.allow_store_recovery", false)
viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true)
viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict")
viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true)
viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false)
viper.SetDefault("gateway.openai_ws.responses_websockets", false)
viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true)
viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128)
viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4)
viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12)
viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true)
viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10)
viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900)
viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120)
viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7)
viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64)
viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1)
viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10)
viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300)
viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30)
viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120)
viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000)
viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true)
viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
......@@ -1747,6 +1896,118 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
}
if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 {
return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive")
}
if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 {
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative")
}
if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 {
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative")
}
if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount {
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account")
}
if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount {
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account")
}
if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 {
return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive")
}
if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 {
return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive")
}
if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 {
return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]")
}
if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 {
return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive")
}
if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 {
return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive")
}
if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 {
return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative")
}
if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 {
return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative")
}
if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 {
return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 &&
c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS {
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms")
}
if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 {
return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]")
}
if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
case "off", "shared", "dedicated":
default:
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
switch mode {
case "strict", "adaptive", "off":
default:
return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off")
}
}
if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
}
if c.Gateway.OpenAIWS.LBTopK <= 0 {
return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
}
if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive")
}
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive")
}
if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
}
if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
}
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load +
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue +
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate +
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT
if weightSum <= 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
}
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
......
......@@ -6,6 +6,7 @@ import (
"time"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
)
func resetViperWithJWTSecret(t *testing.T) {
......@@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
}
}
func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Gateway.OpenAIWS.Enabled {
t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true")
}
if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true")
}
if cfg.Gateway.OpenAIWS.ResponsesWebsockets {
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false")
}
if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled {
t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true")
}
if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 {
t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor)
}
if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 {
t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor)
}
if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 {
t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds)
}
if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback {
t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true")
}
if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld {
t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true")
}
if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled {
t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true")
}
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 {
t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
}
if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 {
t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds)
}
if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 {
t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize)
}
if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 {
t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS)
}
if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 {
t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS)
}
if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 {
t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS)
}
if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 {
t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS)
}
if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 {
t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio)
}
if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 {
t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS)
}
if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 {
t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate)
}
if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true")
}
if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
}
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
}
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
}
}
func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 {
t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
}
}
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
resetViperWithJWTSecret(t)
......@@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
{
name: "gateway openai ws oauth max conns factor",
mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 },
wantErr: "gateway.openai_ws.oauth_max_conns_factor",
},
{
name: "gateway openai ws apikey max conns factor",
mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
wantErr: "gateway.openai_ws.apikey_max_conns_factor",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
......@@ -1174,6 +1282,165 @@ func TestValidateConfigErrors(t *testing.T) {
}
}
func TestValidateConfig_OpenAIWSRules(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
return cfg
}
t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) {
cfg := buildValid(t)
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200
require.NoError(t, cfg.Validate())
require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
})
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "max_conns_per_account 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 },
wantErr: "gateway.openai_ws.max_conns_per_account",
},
{
name: "min_idle_per_account 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 },
wantErr: "gateway.openai_ws.min_idle_per_account",
},
{
name: "max_idle_per_account 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 },
wantErr: "gateway.openai_ws.max_idle_per_account",
},
{
name: "min_idle_per_account 不能大于 max_idle_per_account",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.MinIdlePerAccount = 3
c.Gateway.OpenAIWS.MaxIdlePerAccount = 2
},
wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account",
},
{
name: "max_idle_per_account 不能大于 max_conns_per_account",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.MaxConnsPerAccount = 2
c.Gateway.OpenAIWS.MinIdlePerAccount = 1
c.Gateway.OpenAIWS.MaxIdlePerAccount = 3
},
wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account",
},
{
name: "dial_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.dial_timeout_seconds",
},
{
name: "read_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.read_timeout_seconds",
},
{
name: "write_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.write_timeout_seconds",
},
{
name: "pool_target_utilization 必须在 (0,1]",
mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 },
wantErr: "gateway.openai_ws.pool_target_utilization",
},
{
name: "queue_limit_per_conn 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 },
wantErr: "gateway.openai_ws.queue_limit_per_conn",
},
{
name: "fallback_cooldown_seconds 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 },
wantErr: "gateway.openai_ws.fallback_cooldown_seconds",
},
{
name: "store_disabled_conn_mode 必须为 strict|adaptive|off",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" },
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
name: "ingress_mode_default 必须为 off|shared|dedicated",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},
{
name: "payload_log_sample_rate 必须在 [0,1] 范围内",
mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 },
wantErr: "gateway.openai_ws.payload_log_sample_rate",
},
{
name: "retry_total_budget_ms 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
wantErr: "gateway.openai_ws.retry_total_budget_ms",
},
{
name: "lb_top_k 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
wantErr: "gateway.openai_ws.lb_top_k",
},
{
name: "sticky_session_ttl_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 },
wantErr: "gateway.openai_ws.sticky_session_ttl_seconds",
},
{
name: "sticky_response_id_ttl_seconds 必须为正数",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0
},
wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds",
},
{
name: "sticky_previous_response_ttl_seconds 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 },
wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds",
},
{
name: "scheduler_score_weights 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 },
wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative",
},
{
name: "scheduler_score_weights 不能全为 0",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0
},
wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero",
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
cfg := buildValid(t)
tc.mutate(cfg)
err := cfg.Validate()
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantErr)
})
}
}
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
......
......@@ -104,6 +104,9 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
// Gemini 3.1 image preview 映射
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
// Gemini 3 image 兼容映射(向 3.1 image 迁移)
"gemini-3-pro-image": "gemini-3.1-flash-image",
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
......
package domain
import "testing"
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
t.Parallel()
cases := map[string]string{
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
"gemini-3-pro-image": "gemini-3.1-flash-image",
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
}
for from, want := range cases {
got, ok := DefaultAntigravityModelMapping[from]
if !ok {
t.Fatalf("expected mapping for %q to exist", from)
}
if got != want {
t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
}
}
}
......@@ -1337,6 +1337,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
response.Success(c, stats)
}
// BatchTodayStatsRequest 批量今日统计请求体。
type BatchTodayStatsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required"`
}
// GetBatchTodayStats 批量获取多个账号的今日统计。
// POST /api/v1/admin/accounts/today-stats/batch
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
var req BatchTodayStatsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.AccountIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"stats": stats})
}
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
......
......@@ -3,6 +3,7 @@ package admin
import (
"errors"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
......@@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
......@@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
var requestType *int16
var stream *bool
var billingType *int8
......@@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
......@@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
......@@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
......@@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
......@@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
......
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCapture struct {
service.UsageLogRepository
trendRequestType *int16
trendStream *bool
modelRequestType *int16
modelStream *bool
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
s.trendRequestType = requestType
s.trendStream = stream
return []usagestats.TrendDataPoint{}, nil
}
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, error) {
s.modelRequestType = requestType
s.modelStream = stream
return []usagestats.ModelStat{}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
return router
}
func TestDashboardTrendRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.trendRequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
require.Nil(t, repo.trendStream)
}
func TestDashboardTrendInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardTrendInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.modelRequestType)
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
require.Nil(t, repo.modelStream)
}
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
package admin
import (
"context"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type DataManagementHandler struct {
dataManagementService dataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
type dataManagementService interface {
GetConfig(ctx context.Context) (service.DataManagementConfig, error)
UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
DeleteS3Profile(ctx context.Context, profileID string) error
SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
EnsureAgentEnabled(ctx context.Context) error
GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
}
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
Bucket string `json:"bucket" binding:"required"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type CreateBackupJobRequest struct {
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id"`
PostgresID string `json:"postgres_profile_id"`
RedisID string `json:"redis_profile_id"`
IdempotencyKey string `json:"idempotency_key"`
}
type CreateSourceProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
SetActive bool `json:"set_active"`
}
type UpdateSourceProfileRequest struct {
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
}
type CreateS3ProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
SetActive bool `json:"set_active"`
}
type UpdateS3ProfileRequest struct {
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
health := h.getAgentHealth(c)
payload := gin.H{
"enabled": health.Enabled,
"reason": health.Reason,
"socket_path": health.SocketPath,
}
if health.Agent != nil {
payload["agent"] = gin.H{
"status": health.Agent.Status,
"version": health.Agent.Version,
"uptime_seconds": health.Agent.UptimeSeconds,
}
}
response.Success(c, payload)
}
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
var req service.DataManagementConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) TestS3(c *gin.Context) {
var req TestS3ConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
}
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
var req CreateBackupJobRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
if !h.requireAgentEnabled(c) {
return
}
triggeredBy := "admin:unknown"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
BackupType: req.BackupType,
UploadToS3: req.UploadToS3,
S3ProfileID: req.S3ProfileID,
PostgresID: req.PostgresID,
RedisID: req.RedisID,
TriggeredBy: triggeredBy,
IdempotencyKey: req.IdempotencyKey,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
}
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType == "" {
response.BadRequest(c, "Invalid source_type")
return
}
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
var req CreateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
SourceType: sourceType,
ProfileID: req.ProfileID,
Name: req.Name,
Config: req.Config,
SetActive: req.SetActive,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
var req UpdateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
SourceType: sourceType,
ProfileID: profileID,
Name: req.Name,
Config: req.Config,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
var req CreateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
ProfileID: req.ProfileID,
Name: req.Name,
SetActive: req.SetActive,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
var req UpdateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
ProfileID: profileID,
Name: req.Name,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
pageSize := int32(20)
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
response.BadRequest(c, "Invalid page_size")
return
}
pageSize = int32(v)
}
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
PageSize: pageSize,
PageToken: c.Query("page_token"),
Status: c.Query("status"),
BackupType: c.Query("backup_type"),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
jobID := strings.TrimSpace(c.Param("job_id"))
if jobID == "" {
response.BadRequest(c, "Invalid backup job ID")
return
}
if !h.requireAgentEnabled(c) {
return
}
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, job)
}
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
if h.dataManagementService == nil {
err := infraerrors.ServiceUnavailable(
service.DataManagementAgentUnavailableReason,
"data management agent service is not configured",
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
response.ErrorFrom(c, err)
return false
}
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return false
}
return true
}
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
if h.dataManagementService == nil {
return service.DataManagementAgentHealth{
Enabled: false,
Reason: service.DataManagementAgentUnavailableReason,
SocketPath: service.DefaultDataManagementAgentSocketPath,
}
}
return h.dataManagementService.GetAgentHealth(c.Request.Context())
}
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
headerKey := strings.TrimSpace(headerValue)
if headerKey != "" {
return headerKey
}
return strings.TrimSpace(bodyValue)
}
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type apiEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason"`
Data json.RawMessage `json:"data"`
}
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, 0, envelope.Code)
var data struct {
Enabled bool `json:"enabled"`
Reason string `json:"reason"`
SocketPath string `json:"socket_path"`
}
require.NoError(t, json.Unmarshal(envelope.Data, &data))
require.False(t, data.Enabled)
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
require.Equal(t, svc.SocketPath(), data.SocketPath)
}
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
}
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
}
......@@ -51,6 +51,8 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -84,6 +86,8 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......@@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment