"backend/internal/vscode:/vscode.git/clone" did not exist on "e12dd079fd29a30d0f3a5bc96d69c8e27ef5195c"
Commit 7331220e authored by Edric Li's avatar Edric Li
Browse files

Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
parents fb86002e 4f13c8de
...@@ -59,11 +59,13 @@ type UserEdges struct { ...@@ -59,11 +59,13 @@ type UserEdges struct {
AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"` AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"`
// AllowedGroups holds the value of the allowed_groups edge. // AllowedGroups holds the value of the allowed_groups edge.
AllowedGroups []*Group `json:"allowed_groups,omitempty"` AllowedGroups []*Group `json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge. // UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a // loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not. // type was loaded (or requested) in eager-loading or not.
loadedTypes [6]bool loadedTypes [7]bool
} }
// APIKeysOrErr returns the APIKeys value or an error if the edge // APIKeysOrErr returns the APIKeys value or an error if the edge
...@@ -111,10 +113,19 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) { ...@@ -111,10 +113,19 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
return nil, &NotLoadedError{edge: "allowed_groups"} return nil, &NotLoadedError{edge: "allowed_groups"}
} }
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[5] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading. // was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[5] { if e.loadedTypes[6] {
return e.UserAllowedGroups, nil return e.UserAllowedGroups, nil
} }
return nil, &NotLoadedError{edge: "user_allowed_groups"} return nil, &NotLoadedError{edge: "user_allowed_groups"}
...@@ -265,6 +276,11 @@ func (_m *User) QueryAllowedGroups() *GroupQuery { ...@@ -265,6 +276,11 @@ func (_m *User) QueryAllowedGroups() *GroupQuery {
return NewUserClient(_m.config).QueryAllowedGroups(_m) return NewUserClient(_m.config).QueryAllowedGroups(_m)
} }
// QueryUsageLogs queries the "usage_logs" edge of the User entity.
func (_m *User) QueryUsageLogs() *UsageLogQuery {
return NewUserClient(_m.config).QueryUsageLogs(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m) return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
......
...@@ -49,6 +49,8 @@ const ( ...@@ -49,6 +49,8 @@ const (
EdgeAssignedSubscriptions = "assigned_subscriptions" EdgeAssignedSubscriptions = "assigned_subscriptions"
// EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations. // EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations.
EdgeAllowedGroups = "allowed_groups" EdgeAllowedGroups = "allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups" EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database. // Table holds the table name of the user in the database.
...@@ -86,6 +88,13 @@ const ( ...@@ -86,6 +88,13 @@ const (
// AllowedGroupsInverseTable is the table name for the Group entity. // AllowedGroupsInverseTable is the table name for the Group entity.
// It exists in this package in order to avoid circular dependency with the "group" package. // It exists in this package in order to avoid circular dependency with the "group" package.
AllowedGroupsInverseTable = "groups" AllowedGroupsInverseTable = "groups"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups" UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
...@@ -308,6 +317,20 @@ func ByAllowedGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { ...@@ -308,6 +317,20 @@ func ByAllowedGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
} }
} }
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count. // ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {
...@@ -356,6 +379,13 @@ func newAllowedGroupsStep() *sqlgraph.Step { ...@@ -356,6 +379,13 @@ func newAllowedGroupsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.M2M, false, AllowedGroupsTable, AllowedGroupsPrimaryKey...), sqlgraph.Edge(sqlgraph.M2M, false, AllowedGroupsTable, AllowedGroupsPrimaryKey...),
) )
} }
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step { func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep( return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID), sqlgraph.From(Table, FieldID),
......
...@@ -895,6 +895,29 @@ func HasAllowedGroupsWith(preds ...predicate.Group) predicate.User { ...@@ -895,6 +895,29 @@ func HasAllowedGroupsWith(preds ...predicate.Group) predicate.User {
}) })
} }
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User { func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) { return predicate.User(func(s *sql.Selector) {
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
...@@ -253,6 +254,21 @@ func (_c *UserCreate) AddAllowedGroups(v ...*Group) *UserCreate { ...@@ -253,6 +254,21 @@ func (_c *UserCreate) AddAllowedGroups(v ...*Group) *UserCreate {
return _c.AddAllowedGroupIDs(ids...) return _c.AddAllowedGroupIDs(ids...)
} }
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *UserCreate) AddUsageLogIDs(ids ...int64) *UserCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *UserCreate) AddUsageLogs(v ...*UsageLog) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation { func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation return _c.mutation
...@@ -559,6 +575,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { ...@@ -559,6 +575,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
edge.Target.Fields = specE.Fields edge.Target.Fields = specE.Fields
_spec.Edges = append(_spec.Edges, edge) _spec.Edges = append(_spec.Edges, edge)
} }
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec return _node, _spec
} }
......
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
...@@ -33,6 +34,7 @@ type UserQuery struct { ...@@ -33,6 +34,7 @@ type UserQuery struct {
withSubscriptions *UserSubscriptionQuery withSubscriptions *UserSubscriptionQuery
withAssignedSubscriptions *UserSubscriptionQuery withAssignedSubscriptions *UserSubscriptionQuery
withAllowedGroups *GroupQuery withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withUserAllowedGroups *UserAllowedGroupQuery withUserAllowedGroups *UserAllowedGroupQuery
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
...@@ -180,6 +182,28 @@ func (_q *UserQuery) QueryAllowedGroups() *GroupQuery { ...@@ -180,6 +182,28 @@ func (_q *UserQuery) QueryAllowedGroups() *GroupQuery {
return query return query
} }
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *UserQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.UsageLogsTable, user.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query() query := (&UserAllowedGroupClient{config: _q.config}).Query()
...@@ -399,6 +423,7 @@ func (_q *UserQuery) Clone() *UserQuery { ...@@ -399,6 +423,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withSubscriptions: _q.withSubscriptions.Clone(), withSubscriptions: _q.withSubscriptions.Clone(),
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(), withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
withAllowedGroups: _q.withAllowedGroups.Clone(), withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query. // clone intermediate query.
sql: _q.sql.Clone(), sql: _q.sql.Clone(),
...@@ -461,6 +486,17 @@ func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery { ...@@ -461,6 +486,17 @@ func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery {
return _q return _q
} }
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
...@@ -550,12 +586,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -550,12 +586,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var ( var (
nodes = []*User{} nodes = []*User{}
_spec = _q.querySpec() _spec = _q.querySpec()
loadedTypes = [6]bool{ loadedTypes = [7]bool{
_q.withAPIKeys != nil, _q.withAPIKeys != nil,
_q.withRedeemCodes != nil, _q.withRedeemCodes != nil,
_q.withSubscriptions != nil, _q.withSubscriptions != nil,
_q.withAssignedSubscriptions != nil, _q.withAssignedSubscriptions != nil,
_q.withAllowedGroups != nil, _q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withUserAllowedGroups != nil, _q.withUserAllowedGroups != nil,
} }
) )
...@@ -614,6 +651,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -614,6 +651,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err return nil, err
} }
} }
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *User) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *User, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil { if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes, if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
...@@ -811,6 +855,36 @@ func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, n ...@@ -811,6 +855,36 @@ func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, n
} }
return nil return nil
} }
func (_q *UserQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*User, init func(*User), assign func(*User, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldUserID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes)) fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User) nodeids := make(map[int64]*User)
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
...@@ -273,6 +274,21 @@ func (_u *UserUpdate) AddAllowedGroups(v ...*Group) *UserUpdate { ...@@ -273,6 +274,21 @@ func (_u *UserUpdate) AddAllowedGroups(v ...*Group) *UserUpdate {
return _u.AddAllowedGroupIDs(ids...) return _u.AddAllowedGroupIDs(ids...)
} }
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserUpdate) AddUsageLogIDs(ids ...int64) *UserUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdate) AddUsageLogs(v ...*UsageLog) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation { func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation return _u.mutation
...@@ -383,6 +399,27 @@ func (_u *UserUpdate) RemoveAllowedGroups(v ...*Group) *UserUpdate { ...@@ -383,6 +399,27 @@ func (_u *UserUpdate) RemoveAllowedGroups(v ...*Group) *UserUpdate {
return _u.RemoveAllowedGroupIDs(ids...) return _u.RemoveAllowedGroupIDs(ids...)
} }
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdate) ClearUsageLogs() *UserUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserUpdate) RemoveUsageLogIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserUpdate) RemoveUsageLogs(v ...*UsageLog) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation. // Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) { func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil { if err := _u.defaults(); err != nil {
...@@ -751,6 +788,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -751,6 +788,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
edge.Target.Fields = specE.Fields edge.Target.Fields = specE.Fields
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label} err = &NotFoundError{user.Label}
...@@ -1012,6 +1094,21 @@ func (_u *UserUpdateOne) AddAllowedGroups(v ...*Group) *UserUpdateOne { ...@@ -1012,6 +1094,21 @@ func (_u *UserUpdateOne) AddAllowedGroups(v ...*Group) *UserUpdateOne {
return _u.AddAllowedGroupIDs(ids...) return _u.AddAllowedGroupIDs(ids...)
} }
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserUpdateOne) AddUsageLogIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdateOne) AddUsageLogs(v ...*UsageLog) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation { func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation return _u.mutation
...@@ -1122,6 +1219,27 @@ func (_u *UserUpdateOne) RemoveAllowedGroups(v ...*Group) *UserUpdateOne { ...@@ -1122,6 +1219,27 @@ func (_u *UserUpdateOne) RemoveAllowedGroups(v ...*Group) *UserUpdateOne {
return _u.RemoveAllowedGroupIDs(ids...) return _u.RemoveAllowedGroupIDs(ids...)
} }
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdateOne) ClearUsageLogs() *UserUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder. // Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...) _u.mutation.Where(ps...)
...@@ -1520,6 +1638,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { ...@@ -1520,6 +1638,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
edge.Target.Fields = specE.Fields edge.Target.Fields = specE.Fields
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config} _node = &User{config: _u.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues
......
...@@ -23,6 +23,8 @@ type UserSubscription struct { ...@@ -23,6 +23,8 @@ type UserSubscription struct {
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field. // UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"` UpdatedAt time.Time `json:"updated_at,omitempty"`
// DeletedAt holds the value of the "deleted_at" field.
DeletedAt *time.Time `json:"deleted_at,omitempty"`
// UserID holds the value of the "user_id" field. // UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"` UserID int64 `json:"user_id,omitempty"`
// GroupID holds the value of the "group_id" field. // GroupID holds the value of the "group_id" field.
...@@ -65,9 +67,11 @@ type UserSubscriptionEdges struct { ...@@ -65,9 +67,11 @@ type UserSubscriptionEdges struct {
Group *Group `json:"group,omitempty"` Group *Group `json:"group,omitempty"`
// AssignedByUser holds the value of the assigned_by_user edge. // AssignedByUser holds the value of the assigned_by_user edge.
AssignedByUser *User `json:"assigned_by_user,omitempty"` AssignedByUser *User `json:"assigned_by_user,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// loadedTypes holds the information for reporting if a // loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not. // type was loaded (or requested) in eager-loading or not.
loadedTypes [3]bool loadedTypes [4]bool
} }
// UserOrErr returns the User value or an error if the edge // UserOrErr returns the User value or an error if the edge
...@@ -103,6 +107,15 @@ func (e UserSubscriptionEdges) AssignedByUserOrErr() (*User, error) { ...@@ -103,6 +107,15 @@ func (e UserSubscriptionEdges) AssignedByUserOrErr() (*User, error) {
return nil, &NotLoadedError{edge: "assigned_by_user"} return nil, &NotLoadedError{edge: "assigned_by_user"}
} }
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e UserSubscriptionEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[3] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
}
// scanValues returns the types for scanning values from sql.Rows. // scanValues returns the types for scanning values from sql.Rows.
func (*UserSubscription) scanValues(columns []string) ([]any, error) { func (*UserSubscription) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
...@@ -114,7 +127,7 @@ func (*UserSubscription) scanValues(columns []string) ([]any, error) { ...@@ -114,7 +127,7 @@ func (*UserSubscription) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usersubscription.FieldStatus, usersubscription.FieldNotes: case usersubscription.FieldStatus, usersubscription.FieldNotes:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt: case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldDeletedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
default: default:
values[i] = new(sql.UnknownType) values[i] = new(sql.UnknownType)
...@@ -149,6 +162,13 @@ func (_m *UserSubscription) assignValues(columns []string, values []any) error { ...@@ -149,6 +162,13 @@ func (_m *UserSubscription) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.UpdatedAt = value.Time _m.UpdatedAt = value.Time
} }
case usersubscription.FieldDeletedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
} else if value.Valid {
_m.DeletedAt = new(time.Time)
*_m.DeletedAt = value.Time
}
case usersubscription.FieldUserID: case usersubscription.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok { if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i]) return fmt.Errorf("unexpected type %T for field user_id", values[i])
...@@ -266,6 +286,11 @@ func (_m *UserSubscription) QueryAssignedByUser() *UserQuery { ...@@ -266,6 +286,11 @@ func (_m *UserSubscription) QueryAssignedByUser() *UserQuery {
return NewUserSubscriptionClient(_m.config).QueryAssignedByUser(_m) return NewUserSubscriptionClient(_m.config).QueryAssignedByUser(_m)
} }
// QueryUsageLogs queries the "usage_logs" edge of the UserSubscription entity.
func (_m *UserSubscription) QueryUsageLogs() *UsageLogQuery {
return NewUserSubscriptionClient(_m.config).QueryUsageLogs(_m)
}
// Update returns a builder for updating this UserSubscription. // Update returns a builder for updating this UserSubscription.
// Note that you need to call UserSubscription.Unwrap() before calling this method if this UserSubscription // Note that you need to call UserSubscription.Unwrap() before calling this method if this UserSubscription
// was returned from a transaction, and the transaction was committed or rolled back. // was returned from a transaction, and the transaction was committed or rolled back.
...@@ -295,6 +320,11 @@ func (_m *UserSubscription) String() string { ...@@ -295,6 +320,11 @@ func (_m *UserSubscription) String() string {
builder.WriteString("updated_at=") builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.DeletedAt; v != nil {
builder.WriteString("deleted_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("user_id=") builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID)) builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ") builder.WriteString(", ")
......
...@@ -5,6 +5,7 @@ package usersubscription ...@@ -5,6 +5,7 @@ package usersubscription
import ( import (
"time" "time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
) )
...@@ -18,6 +19,8 @@ const ( ...@@ -18,6 +19,8 @@ const (
FieldCreatedAt = "created_at" FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database. // FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at" FieldUpdatedAt = "updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt = "deleted_at"
// FieldUserID holds the string denoting the user_id field in the database. // FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id" FieldUserID = "user_id"
// FieldGroupID holds the string denoting the group_id field in the database. // FieldGroupID holds the string denoting the group_id field in the database.
...@@ -52,6 +55,8 @@ const ( ...@@ -52,6 +55,8 @@ const (
EdgeGroup = "group" EdgeGroup = "group"
// EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations. // EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations.
EdgeAssignedByUser = "assigned_by_user" EdgeAssignedByUser = "assigned_by_user"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// Table holds the table name of the usersubscription in the database. // Table holds the table name of the usersubscription in the database.
Table = "user_subscriptions" Table = "user_subscriptions"
// UserTable is the table that holds the user relation/edge. // UserTable is the table that holds the user relation/edge.
...@@ -75,6 +80,13 @@ const ( ...@@ -75,6 +80,13 @@ const (
AssignedByUserInverseTable = "users" AssignedByUserInverseTable = "users"
// AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge. // AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge.
AssignedByUserColumn = "assigned_by" AssignedByUserColumn = "assigned_by"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "subscription_id"
) )
// Columns holds all SQL columns for usersubscription fields. // Columns holds all SQL columns for usersubscription fields.
...@@ -82,6 +94,7 @@ var Columns = []string{ ...@@ -82,6 +94,7 @@ var Columns = []string{
FieldID, FieldID,
FieldCreatedAt, FieldCreatedAt,
FieldUpdatedAt, FieldUpdatedAt,
FieldDeletedAt,
FieldUserID, FieldUserID,
FieldGroupID, FieldGroupID,
FieldStartsAt, FieldStartsAt,
...@@ -108,7 +121,14 @@ func ValidColumn(column string) bool { ...@@ -108,7 +121,14 @@ func ValidColumn(column string) bool {
return false return false
} }
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var ( var (
Hooks [1]ent.Hook
Interceptors [1]ent.Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field. // DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field. // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
...@@ -147,6 +167,11 @@ func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { ...@@ -147,6 +167,11 @@ func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
} }
// ByDeletedAt orders the results by the deleted_at field.
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field. // ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption { func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc() return sql.OrderByField(FieldUserID, opts...).ToFunc()
...@@ -237,6 +262,20 @@ func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOptio ...@@ -237,6 +262,20 @@ func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOptio
sqlgraph.OrderByNeighborTerms(s, newAssignedByUserStep(), sql.OrderByField(field, opts...)) sqlgraph.OrderByNeighborTerms(s, newAssignedByUserStep(), sql.OrderByField(field, opts...))
} }
} }
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newUserStep() *sqlgraph.Step { func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep( return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID), sqlgraph.From(Table, FieldID),
...@@ -258,3 +297,10 @@ func newAssignedByUserStep() *sqlgraph.Step { ...@@ -258,3 +297,10 @@ func newAssignedByUserStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn), sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn),
) )
} }
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}
...@@ -65,6 +65,11 @@ func UpdatedAt(v time.Time) predicate.UserSubscription { ...@@ -65,6 +65,11 @@ func UpdatedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v)) return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v))
} }
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func DeletedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. // UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserSubscription { func UserID(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v)) return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
...@@ -215,6 +220,56 @@ func UpdatedAtLTE(v time.Time) predicate.UserSubscription { ...@@ -215,6 +220,56 @@ func UpdatedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldUpdatedAt, v)) return predicate.UserSubscription(sql.FieldLTE(FieldUpdatedAt, v))
} }
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func DeletedAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func DeletedAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldDeletedAt, v))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func DeletedAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldDeletedAt, vs...))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func DeletedAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldDeletedAt, vs...))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func DeletedAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldDeletedAt, v))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func DeletedAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldDeletedAt, v))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func DeletedAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldDeletedAt, v))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func DeletedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldDeletedAt, v))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func DeletedAtIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldDeletedAt))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func DeletedAtNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldDeletedAt))
}
// UserIDEQ applies the EQ predicate on the "user_id" field. // UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserSubscription { func UserIDEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v)) return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
...@@ -884,6 +939,29 @@ func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription { ...@@ -884,6 +939,29 @@ func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription {
}) })
} }
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them. // And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserSubscription) predicate.UserSubscription { func And(predicates ...predicate.UserSubscription) predicate.UserSubscription {
return predicate.UserSubscription(sql.AndPredicates(predicates...)) return predicate.UserSubscription(sql.AndPredicates(predicates...))
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
...@@ -52,6 +53,20 @@ func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscr ...@@ -52,6 +53,20 @@ func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscr
return _c return _c
} }
// SetDeletedAt sets the "deleted_at" field.
func (_c *UserSubscriptionCreate) SetDeletedAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetDeletedAt(v)
return _c
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetDeletedAt(*v)
}
return _c
}
// SetUserID sets the "user_id" field. // SetUserID sets the "user_id" field.
func (_c *UserSubscriptionCreate) SetUserID(v int64) *UserSubscriptionCreate { func (_c *UserSubscriptionCreate) SetUserID(v int64) *UserSubscriptionCreate {
_c.mutation.SetUserID(v) _c.mutation.SetUserID(v)
...@@ -245,6 +260,21 @@ func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCr ...@@ -245,6 +260,21 @@ func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCr
return _c.SetAssignedByUserID(v.ID) return _c.SetAssignedByUserID(v.ID)
} }
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *UserSubscriptionCreate) AddUsageLogIDs(ids ...int64) *UserSubscriptionCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *UserSubscriptionCreate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder. // Mutation returns the UserSubscriptionMutation object of the builder.
func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation { func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
return _c.mutation return _c.mutation
...@@ -252,7 +282,9 @@ func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation { ...@@ -252,7 +282,9 @@ func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
// Save creates the UserSubscription in the database. // Save creates the UserSubscription in the database.
func (_c *UserSubscriptionCreate) Save(ctx context.Context) (*UserSubscription, error) { func (_c *UserSubscriptionCreate) Save(ctx context.Context) (*UserSubscription, error) {
_c.defaults() if err := _c.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
} }
...@@ -279,12 +311,18 @@ func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) { ...@@ -279,12 +311,18 @@ func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) {
} }
// defaults sets the default values of the builder before save. // defaults sets the default values of the builder before save.
func (_c *UserSubscriptionCreate) defaults() { func (_c *UserSubscriptionCreate) defaults() error {
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
if usersubscription.DefaultCreatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultCreatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultCreatedAt() v := usersubscription.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
} }
if _, ok := _c.mutation.UpdatedAt(); !ok { if _, ok := _c.mutation.UpdatedAt(); !ok {
if usersubscription.DefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultUpdatedAt() v := usersubscription.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v) _c.mutation.SetUpdatedAt(v)
} }
...@@ -305,9 +343,13 @@ func (_c *UserSubscriptionCreate) defaults() { ...@@ -305,9 +343,13 @@ func (_c *UserSubscriptionCreate) defaults() {
_c.mutation.SetMonthlyUsageUsd(v) _c.mutation.SetMonthlyUsageUsd(v)
} }
if _, ok := _c.mutation.AssignedAt(); !ok { if _, ok := _c.mutation.AssignedAt(); !ok {
if usersubscription.DefaultAssignedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultAssignedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultAssignedAt() v := usersubscription.DefaultAssignedAt()
_c.mutation.SetAssignedAt(v) _c.mutation.SetAssignedAt(v)
} }
return nil
} }
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
...@@ -391,6 +433,10 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre ...@@ -391,6 +433,10 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value _node.UpdatedAt = value
} }
if value, ok := _c.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
_node.DeletedAt = &value
}
if value, ok := _c.mutation.StartsAt(); ok { if value, ok := _c.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
_node.StartsAt = value _node.StartsAt = value
...@@ -486,6 +532,22 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre ...@@ -486,6 +532,22 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_node.AssignedBy = &nodes[0] _node.AssignedBy = &nodes[0]
_spec.Edges = append(_spec.Edges, edge) _spec.Edges = append(_spec.Edges, edge)
} }
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec return _node, _spec
} }
...@@ -550,6 +612,24 @@ func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert { ...@@ -550,6 +612,24 @@ func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert {
return u return u
} }
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsert) SetDeletedAt(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldDeletedAt, v)
return u
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateDeletedAt() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldDeletedAt)
return u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsert) ClearDeletedAt() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldDeletedAt)
return u
}
// SetUserID sets the "user_id" field. // SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsert) SetUserID(v int64) *UserSubscriptionUpsert { func (u *UserSubscriptionUpsert) SetUserID(v int64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldUserID, v) u.Set(usersubscription.FieldUserID, v)
...@@ -825,6 +905,27 @@ func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne ...@@ -825,6 +905,27 @@ func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne
}) })
} }
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsertOne) SetDeletedAt(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDeletedAt(v)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateDeletedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDeletedAt()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsertOne) ClearDeletedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDeletedAt()
})
}
// SetUserID sets the "user_id" field. // SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertOne) SetUserID(v int64) *UserSubscriptionUpsertOne { func (u *UserSubscriptionUpsertOne) SetUserID(v int64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) { return u.Update(func(s *UserSubscriptionUpsert) {
...@@ -1302,6 +1403,27 @@ func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBu ...@@ -1302,6 +1403,27 @@ func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBu
}) })
} }
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsertBulk) SetDeletedAt(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDeletedAt(v)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateDeletedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDeletedAt()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsertBulk) ClearDeletedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDeletedAt()
})
}
// SetUserID sets the "user_id" field. // SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertBulk) SetUserID(v int64) *UserSubscriptionUpsertBulk { func (u *UserSubscriptionUpsertBulk) SetUserID(v int64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) { return u.Update(func(s *UserSubscriptionUpsert) {
......
...@@ -4,6 +4,7 @@ package ent ...@@ -4,6 +4,7 @@ package ent
import ( import (
"context" "context"
"database/sql/driver"
"fmt" "fmt"
"math" "math"
...@@ -13,6 +14,7 @@ import ( ...@@ -13,6 +14,7 @@ import (
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
...@@ -27,6 +29,7 @@ type UserSubscriptionQuery struct { ...@@ -27,6 +29,7 @@ type UserSubscriptionQuery struct {
withUser *UserQuery withUser *UserQuery
withGroup *GroupQuery withGroup *GroupQuery
withAssignedByUser *UserQuery withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -129,6 +132,28 @@ func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery { ...@@ -129,6 +132,28 @@ func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery {
return query return query
} }
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *UserSubscriptionQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserSubscription entity from the query. // First returns the first UserSubscription entity from the query.
// Returns a *NotFoundError when no UserSubscription was found. // Returns a *NotFoundError when no UserSubscription was found.
func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) { func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) {
...@@ -324,6 +349,7 @@ func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery { ...@@ -324,6 +349,7 @@ func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery {
withUser: _q.withUser.Clone(), withUser: _q.withUser.Clone(),
withGroup: _q.withGroup.Clone(), withGroup: _q.withGroup.Clone(),
withAssignedByUser: _q.withAssignedByUser.Clone(), withAssignedByUser: _q.withAssignedByUser.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
// clone intermediate query. // clone intermediate query.
sql: _q.sql.Clone(), sql: _q.sql.Clone(),
path: _q.path, path: _q.path,
...@@ -363,6 +389,17 @@ func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *U ...@@ -363,6 +389,17 @@ func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *U
return _q return _q
} }
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserSubscriptionQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns. // GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum. // It is often used with aggregate functions, like: count, max, mean, min, sum.
// //
...@@ -441,10 +478,11 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ...@@ -441,10 +478,11 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
var ( var (
nodes = []*UserSubscription{} nodes = []*UserSubscription{}
_spec = _q.querySpec() _spec = _q.querySpec()
loadedTypes = [3]bool{ loadedTypes = [4]bool{
_q.withUser != nil, _q.withUser != nil,
_q.withGroup != nil, _q.withGroup != nil,
_q.withAssignedByUser != nil, _q.withAssignedByUser != nil,
_q.withUsageLogs != nil,
} }
) )
_spec.ScanValues = func(columns []string) ([]any, error) { _spec.ScanValues = func(columns []string) ([]any, error) {
...@@ -483,6 +521,13 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ...@@ -483,6 +521,13 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
return nil, err return nil, err
} }
} }
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *UserSubscription) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *UserSubscription, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
return nodes, nil return nodes, nil
} }
...@@ -576,6 +621,39 @@ func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query * ...@@ -576,6 +621,39 @@ func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query *
} }
return nil return nil
} }
func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*UserSubscription)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldSubscriptionID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(usersubscription.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.SubscriptionID
if fk == nil {
return fmt.Errorf(`foreign-key "subscription_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "subscription_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) { func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
...@@ -36,6 +37,26 @@ func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpd ...@@ -36,6 +37,26 @@ func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpd
return _u return _u
} }
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserSubscriptionUpdate) SetDeletedAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserSubscriptionUpdate) ClearDeletedAt() *UserSubscriptionUpdate {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field. // SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate { func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate {
_u.mutation.SetUserID(v) _u.mutation.SetUserID(v)
...@@ -312,6 +333,21 @@ func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUp ...@@ -312,6 +333,21 @@ func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUp
return _u.SetAssignedByUserID(v.ID) return _u.SetAssignedByUserID(v.ID)
} }
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserSubscriptionUpdate) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder. // Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation { func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation {
return _u.mutation return _u.mutation
...@@ -335,9 +371,32 @@ func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate ...@@ -335,9 +371,32 @@ func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate
return _u return _u
} }
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdate) ClearUsageLogs() *UserSubscriptionUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserSubscriptionUpdate) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserSubscriptionUpdate) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation. // Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) { func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) {
_u.defaults() if err := _u.defaults(); err != nil {
return 0, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
} }
...@@ -364,11 +423,15 @@ func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) { ...@@ -364,11 +423,15 @@ func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) {
} }
// defaults sets the default values of the builder before save. // defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdate) defaults() { func (_u *UserSubscriptionUpdate) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok { if _, ok := _u.mutation.UpdatedAt(); !ok {
if usersubscription.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.UpdateDefaultUpdatedAt() v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v) _u.mutation.SetUpdatedAt(v)
} }
return nil
} }
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
...@@ -402,6 +465,12 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e ...@@ -402,6 +465,12 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
if value, ok := _u.mutation.UpdatedAt(); ok { if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
} }
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.StartsAt(); ok { if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
} }
...@@ -543,6 +612,51 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e ...@@ -543,6 +612,51 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usersubscription.Label} err = &NotFoundError{usersubscription.Label}
...@@ -569,6 +683,26 @@ func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscription ...@@ -569,6 +683,26 @@ func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscription
return _u return _u
} }
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserSubscriptionUpdateOne) SetDeletedAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserSubscriptionUpdateOne) ClearDeletedAt() *UserSubscriptionUpdateOne {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field. // SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne { func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetUserID(v) _u.mutation.SetUserID(v)
...@@ -845,6 +979,21 @@ func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptio ...@@ -845,6 +979,21 @@ func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptio
return _u.SetAssignedByUserID(v.ID) return _u.SetAssignedByUserID(v.ID)
} }
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserSubscriptionUpdateOne) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdateOne) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder. // Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation { func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation {
return _u.mutation return _u.mutation
...@@ -868,6 +1017,27 @@ func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpda ...@@ -868,6 +1017,27 @@ func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpda
return _u return _u
} }
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdateOne) ClearUsageLogs() *UserSubscriptionUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserSubscriptionUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserSubscriptionUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Where appends a list predicates to the UserSubscriptionUpdate builder. // Where appends a list predicates to the UserSubscriptionUpdate builder.
func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne { func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne {
_u.mutation.Where(ps...) _u.mutation.Where(ps...)
...@@ -883,7 +1053,9 @@ func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *Use ...@@ -883,7 +1053,9 @@ func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *Use
// Save executes the query and returns the updated UserSubscription entity. // Save executes the query and returns the updated UserSubscription entity.
func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) { func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) {
_u.defaults() if err := _u.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
} }
...@@ -910,11 +1082,15 @@ func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) { ...@@ -910,11 +1082,15 @@ func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) {
} }
// defaults sets the default values of the builder before save. // defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdateOne) defaults() { func (_u *UserSubscriptionUpdateOne) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok { if _, ok := _u.mutation.UpdatedAt(); !ok {
if usersubscription.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.UpdateDefaultUpdatedAt() v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v) _u.mutation.SetUpdatedAt(v)
} }
return nil
} }
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
...@@ -965,6 +1141,12 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu ...@@ -965,6 +1141,12 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
if value, ok := _u.mutation.UpdatedAt(); ok { if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
} }
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.StartsAt(); ok { if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
} }
...@@ -1106,6 +1288,51 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu ...@@ -1106,6 +1288,51 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserSubscription{config: _u.config} _node = &UserSubscription{config: _u.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues
......
...@@ -3,6 +3,7 @@ package config ...@@ -3,6 +3,7 @@ package config
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
...@@ -12,6 +13,20 @@ const ( ...@@ -12,6 +13,20 @@ const (
RunModeSimple = "simple" RunModeSimple = "simple"
) )
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
// ConnectionPoolIsolationProxy: 按代理隔离
// 同一代理地址共享连接池,适合代理数量少、账户数量多的场景
ConnectionPoolIsolationProxy = "proxy"
// ConnectionPoolIsolationAccount: 按账户隔离
// 每个账户独立连接池,适合账户数量少、需要严格隔离的场景
ConnectionPoolIsolationAccount = "account"
// ConnectionPoolIsolationAccountProxy: 按账户+代理组合隔离(默认)
// 同一账户+代理组合共享连接池,提供最细粒度的隔离
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"` Database DatabaseConfig `mapstructure:"database"`
...@@ -29,6 +44,7 @@ type Config struct { ...@@ -29,6 +44,7 @@ type Config struct {
type GeminiConfig struct { type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"` OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
} }
type GeminiOAuthConfig struct { type GeminiOAuthConfig struct {
...@@ -37,6 +53,17 @@ type GeminiOAuthConfig struct { ...@@ -37,6 +53,17 @@ type GeminiOAuthConfig struct {
Scopes string `mapstructure:"scopes"` Scopes string `mapstructure:"scopes"`
} }
type GeminiQuotaConfig struct {
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
Policy string `mapstructure:"policy"`
}
type GeminiTierQuotaConfig struct {
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// TokenRefreshConfig OAuth token自动刷新配置 // TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct { type TokenRefreshConfig struct {
// 是否启用自动刷新 // 是否启用自动刷新
...@@ -79,12 +106,71 @@ type GatewayConfig struct { ...@@ -79,12 +106,71 @@ type GatewayConfig struct {
// 等待上游响应头的超时时间(秒),0表示无超时 // 等待上游响应头的超时时间(秒),0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间 // 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
MaxIdleConns int `mapstructure:"max_idle_conns"`
// MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率)
MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"`
// MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲),0表示无限制
MaxConnsPerHost int `mapstructure:"max_conns_per_host"`
// IdleConnTimeoutSeconds: 空闲连接超时时间(秒)
IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"`
// MaxUpstreamClients: 上游连接池客户端最大缓存数量
// 当使用连接池隔离策略时,系统会为不同的账户/代理组合创建独立的 HTTP 客户端
// 此参数限制缓存的客户端数量,超出后会淘汰最久未使用的客户端
// 建议值:预估的活跃账户数 * 1.2(留有余量)
MaxUpstreamClients int `mapstructure:"max_upstream_clients"`
// ClientIdleTTLSeconds: 上游连接池客户端空闲回收阈值(秒)
// 超过此时间未使用的客户端会被标记为可回收
// 建议值:根据用户访问频率设置,一般 10-30 分钟
ClientIdleTTLSeconds int `mapstructure:"client_idle_ttl_seconds"`
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type GatewaySchedulingConfig struct {
// 粘性会话排队配置
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
// 兜底排队配置
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 过期槽位清理周期(0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
} }
func (s *ServerConfig) Address() string { func (s *ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port) return fmt.Sprintf("%s:%d", s.Host, s.Port)
} }
// DatabaseConfig 数据库连接配置
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
type DatabaseConfig struct { type DatabaseConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
...@@ -92,6 +178,15 @@ type DatabaseConfig struct { ...@@ -92,6 +178,15 @@ type DatabaseConfig struct {
Password string `mapstructure:"password"` Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"` DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"` SSLMode string `mapstructure:"sslmode"`
// 连接池配置(性能优化:可配置化连接池参数)
// MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽
MaxOpenConns int `mapstructure:"max_open_conns"`
// MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟
MaxIdleConns int `mapstructure:"max_idle_conns"`
// ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏
ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"`
// ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接
ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"`
} }
func (d *DatabaseConfig) DSN() string { func (d *DatabaseConfig) DSN() string {
...@@ -112,11 +207,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string { ...@@ -112,11 +207,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
) )
} }
// RedisConfig Redis 连接配置
// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量
type RedisConfig struct { type RedisConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Password string `mapstructure:"password"` Password string `mapstructure:"password"`
DB int `mapstructure:"db"` DB int `mapstructure:"db"`
// 连接池与超时配置(性能优化:可配置化连接池参数)
// DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
// ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
// WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
// PoolSize: 连接池大小,控制最大并发连接数
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
} }
func (r *RedisConfig) Address() string { func (r *RedisConfig) Address() string {
...@@ -203,12 +311,21 @@ func setDefaults() { ...@@ -203,12 +311,21 @@ func setDefaults() {
viper.SetDefault("database.password", "postgres") viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api") viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable") viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.max_open_conns", 50)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
// Redis // Redis
viper.SetDefault("redis.host", "localhost") viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379) viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "") viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0) viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
// JWT // JWT
viper.SetDefault("jwt.secret", "change-me-in-production") viper.SetDefault("jwt.secret", "change-me-in-production")
...@@ -240,6 +357,26 @@ func setDefaults() { ...@@ -240,6 +357,26 @@ func setDefaults() {
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
// TokenRefresh // TokenRefresh
viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.enabled", true)
...@@ -254,6 +391,7 @@ func setDefaults() { ...@@ -254,6 +391,7 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_id", "") viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
...@@ -263,6 +401,86 @@ func (c *Config) Validate() error { ...@@ -263,6 +401,86 @@ func (c *Config) Validate() error {
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" { if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production") return fmt.Errorf("jwt.secret must be changed in production")
} }
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
}
if c.Database.MaxIdleConns < 0 {
return fmt.Errorf("database.max_idle_conns must be non-negative")
}
if c.Database.MaxIdleConns > c.Database.MaxOpenConns {
return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns")
}
if c.Database.ConnMaxLifetimeMinutes < 0 {
return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative")
}
if c.Database.ConnMaxIdleTimeMinutes < 0 {
return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative")
}
if c.Redis.DialTimeoutSeconds <= 0 {
return fmt.Errorf("redis.dial_timeout_seconds must be positive")
}
if c.Redis.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("redis.read_timeout_seconds must be positive")
}
if c.Redis.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("redis.write_timeout_seconds must be positive")
}
if c.Redis.PoolSize <= 0 {
return fmt.Errorf("redis.pool_size must be positive")
}
if c.Redis.MinIdleConns < 0 {
return fmt.Errorf("redis.min_idle_conns must be non-negative")
}
if c.Redis.MinIdleConns > c.Redis.PoolSize {
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
default:
return fmt.Errorf("gateway.connection_pool_isolation must be one of: %s/%s/%s",
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
}
}
if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive")
}
if c.Gateway.MaxIdleConnsPerHost <= 0 {
return fmt.Errorf("gateway.max_idle_conns_per_host must be positive")
}
if c.Gateway.MaxConnsPerHost < 0 {
return fmt.Errorf("gateway.max_conns_per_host must be non-negative")
}
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
}
if c.Gateway.ClientIdleTTLSeconds <= 0 {
return fmt.Errorf("gateway.client_idle_ttl_seconds must be positive")
}
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
return nil return nil
} }
......
package config package config
import "testing" import (
"testing"
"time"
"github.com/spf13/viper"
)
func TestNormalizeRunMode(t *testing.T) { func TestNormalizeRunMode(t *testing.T) {
tests := []struct { tests := []struct {
...@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) { ...@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) {
} }
} }
} }
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
}
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}
...@@ -10,15 +10,17 @@ import ( ...@@ -10,15 +10,17 @@ import (
// SettingHandler 系统设置处理器 // SettingHandler 系统设置处理器
type SettingHandler struct { type SettingHandler struct {
settingService *service.SettingService settingService *service.SettingService
emailService *service.EmailService emailService *service.EmailService
turnstileService *service.TurnstileService
} }
// NewSettingHandler 创建系统设置处理器 // NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler { func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
return &SettingHandler{ return &SettingHandler{
settingService: settingService, settingService: settingService,
emailService: emailService, emailService: emailService,
turnstileService: turnstileService,
} }
} }
...@@ -108,6 +110,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -108,6 +110,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587 req.SmtpPort = 587
} }
// Turnstile 参数验证
if req.TurnstileEnabled {
// 检查必填字段
if req.TurnstileSiteKey == "" {
response.BadRequest(c, "Turnstile Site Key is required when enabled")
return
}
if req.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
return
}
// 获取当前设置,检查参数是否有变化
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
if siteKeyChanged || secretKeyChanged {
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
response.ErrorFrom(c, err)
return
}
}
}
settings := &service.SystemSettings{ settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled, RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
......
...@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return return
} }
...@@ -76,15 +80,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -76,15 +80,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 解析请求获取模型名和stream parsedReq, err := service.ParseGatewayRequest(body)
var req struct { if err != nil {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
} }
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// Track if we've started streaming (for error handling) // Track if we've started streaming (for error handling)
streamStarted := false streamStarted := false
...@@ -106,7 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -106,7 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位 // 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
...@@ -124,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -124,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 计算粘性会话hash // 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
platform := "" platform := ""
...@@ -133,6 +141,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -133,6 +141,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil { } else if apiKey.Group != nil {
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
const maxAccountSwitches = 3 const maxAccountSwitches = 3
...@@ -141,7 +153,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -141,7 +153,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -150,35 +162,77 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -150,35 +162,77 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream { if selection.Acquired && selection.ReleaseFunc != nil {
sendMockWarmupStream(c, req.Model) selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else { } else {
sendMockWarmupResponse(c, req.Model) sendMockWarmupResponse(c, reqModel)
} }
return return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
log.Printf("Account concurrency acquire failed: %v", err) if !selection.Acquired {
h.handleConcurrencyError(c, err, "account", streamStarted) if selection.WaitPlan == nil {
return h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body) result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
} else { } else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
...@@ -223,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -223,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -232,23 +286,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -232,23 +286,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream { if selection.Acquired && selection.ReleaseFunc != nil {
sendMockWarmupStream(c, req.Model) selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else { } else {
sendMockWarmupResponse(c, req.Model) sendMockWarmupResponse(c, reqModel)
} }
return return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
log.Printf("Account concurrency acquire failed: %v", err) if !selection.Acquired {
h.handleConcurrencyError(c, err, "account", streamStarted) if selection.WaitPlan == nil {
return h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
...@@ -256,11 +349,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -256,11 +349,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else { } else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
...@@ -525,6 +621,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -525,6 +621,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return return
} }
...@@ -534,15 +634,18 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -534,15 +634,18 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
// 解析请求获取模型名 parsedReq, err := service.ParseGatewayRequest(body)
var req struct { if err != nil {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
} }
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// 获取订阅信息(可能为nil) // 获取订阅信息(可能为nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
...@@ -554,17 +657,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -554,17 +657,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
// 计算粘性会话 hash // 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return return
} }
// 转发请求(不记录使用量) // 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil { if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err) log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理 // 错误响应已在 ForwardCountTokens 中处理
return return
......
...@@ -3,6 +3,7 @@ package handler ...@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"net/http" "net/http"
"time" "time"
...@@ -11,11 +12,28 @@ import ( ...@@ -11,11 +12,28 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// 并发槽位等待相关常量
//
// 性能优化说明:
// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题:
// 1. 高并发时频繁轮询增加 Redis 压力
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
//
// 新实现使用指数退避 + 抖动算法:
// 1. 初始退避 100ms,每次乘以 1.5,最大 2s
// 2. 添加 ±20% 的随机抖动,分散重试时间点
// 3. 减少 Redis 压力,避免惊群效应
const ( const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot // maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait // pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second pingInterval = 15 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
backoffMultiplier = 1.5
// maxBackoff 最大退避时间
maxBackoff = 2 * time.Second
) )
// SSEPingFormat defines the format of SSE ping events for different platforms // SSEPingFormat defines the format of SSE ping events for different platforms
...@@ -65,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 ...@@ -65,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID) h.concurrencyService.DecrementWaitCount(ctx, userID)
} }
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
...@@ -108,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ...@@ -108,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller). // streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait) return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel() defer cancel()
// Determine if ping is needed (streaming + ping format defined) // Determine if ping is needed (streaming + ping format defined)
...@@ -131,8 +164,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, ...@@ -131,8 +164,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
pingCh = pingTicker.C pingCh = pingTicker.C
} }
pollTicker := time.NewTicker(100 * time.Millisecond) backoff := initialBackoff
defer pollTicker.Stop() timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for { for {
select { select {
...@@ -156,7 +191,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, ...@@ -156,7 +191,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
} }
flusher.Flush() flusher.Flush()
case <-pollTicker.C: case <-timer.C:
// Try to acquire slot // Try to acquire slot
var result *service.AcquireResult var result *service.AcquireResult
var err error var err error
...@@ -174,6 +209,40 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, ...@@ -174,6 +209,40 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
if result.Acquired { if result.Acquired {
return result.ReleaseFunc, nil return result.ReleaseFunc, nil
} }
backoff = nextBackoff(backoff, rng)
timer.Reset(backoff)
} }
} }
} }
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil,此时不添加抖动)
// 返回值:下一次退避时间(100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff
}
if jittered > maxBackoff {
return maxBackoff
}
return jittered
}
...@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
return
}
googleError(c, http.StatusBadRequest, "Failed to read request body") googleError(c, http.StatusBadRequest, "Failed to read request body")
return return
} }
...@@ -191,14 +195,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -191,14 +195,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body) parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3 const maxAccountSwitches = 3
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
...@@ -207,12 +216,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -207,12 +216,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus) handleGeminiFailoverExhausted(c, lastFailoverStatus)
return return
} }
account := selection.Account
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
googleError(c, http.StatusTooManyRequests, err.Error()) if !selection.Acquired {
return if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
...@@ -225,6 +270,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -225,6 +270,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
......
...@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Read request body // Read request body
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return return
} }
...@@ -76,6 +80,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -76,6 +80,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel, _ := reqBody["model"].(string) reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool) reqStream, _ := reqBody["stream"].(bool)
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// For non-Codex CLI requests, set default instructions // For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) { if !openai.IsCodexCLIRequest(userAgent) {
...@@ -136,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -136,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for { for {
// Select account supporting the requested model // Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
...@@ -146,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -146,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
log.Printf("Account concurrency acquire failed: %v", err) if !selection.Acquired {
h.handleConcurrencyError(c, err, "account", streamStarted) if selection.WaitPlan == nil {
return h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// Forward request // Forward request
...@@ -161,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -161,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
......
package handler
import (
"errors"
"fmt"
"net/http"
)
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
return maxErr, true
}
return nil, false
}
func formatBodyLimit(limit int64) string {
const mb = 1024 * 1024
if limit >= mb {
return fmt.Sprintf("%dMB", limit/mb)
}
return fmt.Sprintf("%dB", limit)
}
func buildBodyTooLargeMessage(limit int64) string {
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment