Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7331220e
Commit
7331220e
authored
Jan 01, 2026
by
Edric Li
Browse files
Merge remote-tracking branch 'upstream/main'
# Conflicts: # frontend/src/components/account/CreateAccountModal.vue
parents
fb86002e
4f13c8de
Changes
215
Hide whitespace changes
Inline
Side-by-side
backend/ent/user.go
View file @
7331220e
...
...
@@ -59,11 +59,13 @@ type UserEdges struct {
AssignedSubscriptions
[]
*
UserSubscription
`json:"assigned_subscriptions,omitempty"`
// AllowedGroups holds the value of the allowed_groups edge.
AllowedGroups
[]
*
Group
`json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs
[]
*
UsageLog
`json:"usage_logs,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups
[]
*
UserAllowedGroup
`json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes
[
6
]
bool
loadedTypes
[
7
]
bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
...
...
@@ -111,10 +113,19 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
return
nil
,
&
NotLoadedError
{
edge
:
"allowed_groups"
}
}
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func
(
e
UserEdges
)
UsageLogsOrErr
()
([]
*
UsageLog
,
error
)
{
if
e
.
loadedTypes
[
5
]
{
return
e
.
UsageLogs
,
nil
}
return
nil
,
&
NotLoadedError
{
edge
:
"usage_logs"
}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func
(
e
UserEdges
)
UserAllowedGroupsOrErr
()
([]
*
UserAllowedGroup
,
error
)
{
if
e
.
loadedTypes
[
5
]
{
if
e
.
loadedTypes
[
6
]
{
return
e
.
UserAllowedGroups
,
nil
}
return
nil
,
&
NotLoadedError
{
edge
:
"user_allowed_groups"
}
...
...
@@ -265,6 +276,11 @@ func (_m *User) QueryAllowedGroups() *GroupQuery {
return
NewUserClient
(
_m
.
config
)
.
QueryAllowedGroups
(
_m
)
}
// QueryUsageLogs queries the "usage_logs" edge of the User entity.
func
(
_m
*
User
)
QueryUsageLogs
()
*
UsageLogQuery
{
return
NewUserClient
(
_m
.
config
)
.
QueryUsageLogs
(
_m
)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func
(
_m
*
User
)
QueryUserAllowedGroups
()
*
UserAllowedGroupQuery
{
return
NewUserClient
(
_m
.
config
)
.
QueryUserAllowedGroups
(
_m
)
...
...
backend/ent/user/user.go
View file @
7331220e
...
...
@@ -49,6 +49,8 @@ const (
EdgeAssignedSubscriptions
=
"assigned_subscriptions"
// EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations.
EdgeAllowedGroups
=
"allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs
=
"usage_logs"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups
=
"user_allowed_groups"
// Table holds the table name of the user in the database.
...
...
@@ -86,6 +88,13 @@ const (
// AllowedGroupsInverseTable is the table name for the Group entity.
// It exists in this package in order to avoid circular dependency with the "group" package.
AllowedGroupsInverseTable
=
"groups"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable
=
"usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable
=
"usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn
=
"user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable
=
"user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
...
...
@@ -308,6 +317,20 @@ func ByAllowedGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func
ByUsageLogsCount
(
opts
...
sql
.
OrderTermOption
)
OrderOption
{
return
func
(
s
*
sql
.
Selector
)
{
sqlgraph
.
OrderByNeighborsCount
(
s
,
newUsageLogsStep
(),
opts
...
)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func
ByUsageLogs
(
term
sql
.
OrderTerm
,
terms
...
sql
.
OrderTerm
)
OrderOption
{
return
func
(
s
*
sql
.
Selector
)
{
sqlgraph
.
OrderByNeighborTerms
(
s
,
newUsageLogsStep
(),
append
([]
sql
.
OrderTerm
{
term
},
terms
...
)
...
)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func
ByUserAllowedGroupsCount
(
opts
...
sql
.
OrderTermOption
)
OrderOption
{
return
func
(
s
*
sql
.
Selector
)
{
...
...
@@ -356,6 +379,13 @@ func newAllowedGroupsStep() *sqlgraph.Step {
sqlgraph
.
Edge
(
sqlgraph
.
M2M
,
false
,
AllowedGroupsTable
,
AllowedGroupsPrimaryKey
...
),
)
}
func
newUsageLogsStep
()
*
sqlgraph
.
Step
{
return
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
Table
,
FieldID
),
sqlgraph
.
To
(
UsageLogsInverseTable
,
FieldID
),
sqlgraph
.
Edge
(
sqlgraph
.
O2M
,
false
,
UsageLogsTable
,
UsageLogsColumn
),
)
}
func
newUserAllowedGroupsStep
()
*
sqlgraph
.
Step
{
return
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
Table
,
FieldID
),
...
...
backend/ent/user/where.go
View file @
7331220e
...
...
@@ -895,6 +895,29 @@ func HasAllowedGroupsWith(preds ...predicate.Group) predicate.User {
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func
HasUsageLogs
()
predicate
.
User
{
return
predicate
.
User
(
func
(
s
*
sql
.
Selector
)
{
step
:=
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
Table
,
FieldID
),
sqlgraph
.
Edge
(
sqlgraph
.
O2M
,
false
,
UsageLogsTable
,
UsageLogsColumn
),
)
sqlgraph
.
HasNeighbors
(
s
,
step
)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func
HasUsageLogsWith
(
preds
...
predicate
.
UsageLog
)
predicate
.
User
{
return
predicate
.
User
(
func
(
s
*
sql
.
Selector
)
{
step
:=
newUsageLogsStep
()
sqlgraph
.
HasNeighborsWith
(
s
,
step
,
func
(
s
*
sql
.
Selector
)
{
for
_
,
p
:=
range
preds
{
p
(
s
)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func
HasUserAllowedGroups
()
predicate
.
User
{
return
predicate
.
User
(
func
(
s
*
sql
.
Selector
)
{
...
...
backend/ent/user_create.go
View file @
7331220e
...
...
@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
...
...
@@ -253,6 +254,21 @@ func (_c *UserCreate) AddAllowedGroups(v ...*Group) *UserCreate {
return
_c
.
AddAllowedGroupIDs
(
ids
...
)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func
(
_c
*
UserCreate
)
AddUsageLogIDs
(
ids
...
int64
)
*
UserCreate
{
_c
.
mutation
.
AddUsageLogIDs
(
ids
...
)
return
_c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func
(
_c
*
UserCreate
)
AddUsageLogs
(
v
...*
UsageLog
)
*
UserCreate
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_c
.
AddUsageLogIDs
(
ids
...
)
}
// Mutation returns the UserMutation object of the builder.
func
(
_c
*
UserCreate
)
Mutation
()
*
UserMutation
{
return
_c
.
mutation
...
...
@@ -559,6 +575,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
edge
.
Target
.
Fields
=
specE
.
Fields
_spec
.
Edges
=
append
(
_spec
.
Edges
,
edge
)
}
if
nodes
:=
_c
.
mutation
.
UsageLogsIDs
();
len
(
nodes
)
>
0
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
user
.
UsageLogsTable
,
Columns
:
[]
string
{
user
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
=
append
(
_spec
.
Edges
,
edge
)
}
return
_node
,
_spec
}
...
...
backend/ent/user_query.go
View file @
7331220e
...
...
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
...
...
@@ -33,6 +34,7 @@ type UserQuery struct {
withSubscriptions
*
UserSubscriptionQuery
withAssignedSubscriptions
*
UserSubscriptionQuery
withAllowedGroups
*
GroupQuery
withUsageLogs
*
UsageLogQuery
withUserAllowedGroups
*
UserAllowedGroupQuery
// intermediate query (i.e. traversal path).
sql
*
sql
.
Selector
...
...
@@ -180,6 +182,28 @@ func (_q *UserQuery) QueryAllowedGroups() *GroupQuery {
return
query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func
(
_q
*
UserQuery
)
QueryUsageLogs
()
*
UsageLogQuery
{
query
:=
(
&
UsageLogClient
{
config
:
_q
.
config
})
.
Query
()
query
.
path
=
func
(
ctx
context
.
Context
)
(
fromU
*
sql
.
Selector
,
err
error
)
{
if
err
:=
_q
.
prepareQuery
(
ctx
);
err
!=
nil
{
return
nil
,
err
}
selector
:=
_q
.
sqlQuery
(
ctx
)
if
err
:=
selector
.
Err
();
err
!=
nil
{
return
nil
,
err
}
step
:=
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
user
.
Table
,
user
.
FieldID
,
selector
),
sqlgraph
.
To
(
usagelog
.
Table
,
usagelog
.
FieldID
),
sqlgraph
.
Edge
(
sqlgraph
.
O2M
,
false
,
user
.
UsageLogsTable
,
user
.
UsageLogsColumn
),
)
fromU
=
sqlgraph
.
SetNeighbors
(
_q
.
driver
.
Dialect
(),
step
)
return
fromU
,
nil
}
return
query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func
(
_q
*
UserQuery
)
QueryUserAllowedGroups
()
*
UserAllowedGroupQuery
{
query
:=
(
&
UserAllowedGroupClient
{
config
:
_q
.
config
})
.
Query
()
...
...
@@ -399,6 +423,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withSubscriptions
:
_q
.
withSubscriptions
.
Clone
(),
withAssignedSubscriptions
:
_q
.
withAssignedSubscriptions
.
Clone
(),
withAllowedGroups
:
_q
.
withAllowedGroups
.
Clone
(),
withUsageLogs
:
_q
.
withUsageLogs
.
Clone
(),
withUserAllowedGroups
:
_q
.
withUserAllowedGroups
.
Clone
(),
// clone intermediate query.
sql
:
_q
.
sql
.
Clone
(),
...
...
@@ -461,6 +486,17 @@ func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery {
return
_q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func
(
_q
*
UserQuery
)
WithUsageLogs
(
opts
...
func
(
*
UsageLogQuery
))
*
UserQuery
{
query
:=
(
&
UsageLogClient
{
config
:
_q
.
config
})
.
Query
()
for
_
,
opt
:=
range
opts
{
opt
(
query
)
}
_q
.
withUsageLogs
=
query
return
_q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func
(
_q
*
UserQuery
)
WithUserAllowedGroups
(
opts
...
func
(
*
UserAllowedGroupQuery
))
*
UserQuery
{
...
...
@@ -550,12 +586,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var
(
nodes
=
[]
*
User
{}
_spec
=
_q
.
querySpec
()
loadedTypes
=
[
6
]
bool
{
loadedTypes
=
[
7
]
bool
{
_q
.
withAPIKeys
!=
nil
,
_q
.
withRedeemCodes
!=
nil
,
_q
.
withSubscriptions
!=
nil
,
_q
.
withAssignedSubscriptions
!=
nil
,
_q
.
withAllowedGroups
!=
nil
,
_q
.
withUsageLogs
!=
nil
,
_q
.
withUserAllowedGroups
!=
nil
,
}
)
...
...
@@ -614,6 +651,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return
nil
,
err
}
}
if
query
:=
_q
.
withUsageLogs
;
query
!=
nil
{
if
err
:=
_q
.
loadUsageLogs
(
ctx
,
query
,
nodes
,
func
(
n
*
User
)
{
n
.
Edges
.
UsageLogs
=
[]
*
UsageLog
{}
},
func
(
n
*
User
,
e
*
UsageLog
)
{
n
.
Edges
.
UsageLogs
=
append
(
n
.
Edges
.
UsageLogs
,
e
)
});
err
!=
nil
{
return
nil
,
err
}
}
if
query
:=
_q
.
withUserAllowedGroups
;
query
!=
nil
{
if
err
:=
_q
.
loadUserAllowedGroups
(
ctx
,
query
,
nodes
,
func
(
n
*
User
)
{
n
.
Edges
.
UserAllowedGroups
=
[]
*
UserAllowedGroup
{}
},
...
...
@@ -811,6 +855,36 @@ func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, n
}
return
nil
}
func
(
_q
*
UserQuery
)
loadUsageLogs
(
ctx
context
.
Context
,
query
*
UsageLogQuery
,
nodes
[]
*
User
,
init
func
(
*
User
),
assign
func
(
*
User
,
*
UsageLog
))
error
{
fks
:=
make
([]
driver
.
Value
,
0
,
len
(
nodes
))
nodeids
:=
make
(
map
[
int64
]
*
User
)
for
i
:=
range
nodes
{
fks
=
append
(
fks
,
nodes
[
i
]
.
ID
)
nodeids
[
nodes
[
i
]
.
ID
]
=
nodes
[
i
]
if
init
!=
nil
{
init
(
nodes
[
i
])
}
}
if
len
(
query
.
ctx
.
Fields
)
>
0
{
query
.
ctx
.
AppendFieldOnce
(
usagelog
.
FieldUserID
)
}
query
.
Where
(
predicate
.
UsageLog
(
func
(
s
*
sql
.
Selector
)
{
s
.
Where
(
sql
.
InValues
(
s
.
C
(
user
.
UsageLogsColumn
),
fks
...
))
}))
neighbors
,
err
:=
query
.
All
(
ctx
)
if
err
!=
nil
{
return
err
}
for
_
,
n
:=
range
neighbors
{
fk
:=
n
.
UserID
node
,
ok
:=
nodeids
[
fk
]
if
!
ok
{
return
fmt
.
Errorf
(
`unexpected referenced foreign-key "user_id" returned %v for node %v`
,
fk
,
n
.
ID
)
}
assign
(
node
,
n
)
}
return
nil
}
func
(
_q
*
UserQuery
)
loadUserAllowedGroups
(
ctx
context
.
Context
,
query
*
UserAllowedGroupQuery
,
nodes
[]
*
User
,
init
func
(
*
User
),
assign
func
(
*
User
,
*
UserAllowedGroup
))
error
{
fks
:=
make
([]
driver
.
Value
,
0
,
len
(
nodes
))
nodeids
:=
make
(
map
[
int64
]
*
User
)
...
...
backend/ent/user_update.go
View file @
7331220e
...
...
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
...
...
@@ -273,6 +274,21 @@ func (_u *UserUpdate) AddAllowedGroups(v ...*Group) *UserUpdate {
return
_u
.
AddAllowedGroupIDs
(
ids
...
)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func
(
_u
*
UserUpdate
)
AddUsageLogIDs
(
ids
...
int64
)
*
UserUpdate
{
_u
.
mutation
.
AddUsageLogIDs
(
ids
...
)
return
_u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserUpdate
)
AddUsageLogs
(
v
...*
UsageLog
)
*
UserUpdate
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
AddUsageLogIDs
(
ids
...
)
}
// Mutation returns the UserMutation object of the builder.
func
(
_u
*
UserUpdate
)
Mutation
()
*
UserMutation
{
return
_u
.
mutation
...
...
@@ -383,6 +399,27 @@ func (_u *UserUpdate) RemoveAllowedGroups(v ...*Group) *UserUpdate {
return
_u
.
RemoveAllowedGroupIDs
(
ids
...
)
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserUpdate
)
ClearUsageLogs
()
*
UserUpdate
{
_u
.
mutation
.
ClearUsageLogs
()
return
_u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func
(
_u
*
UserUpdate
)
RemoveUsageLogIDs
(
ids
...
int64
)
*
UserUpdate
{
_u
.
mutation
.
RemoveUsageLogIDs
(
ids
...
)
return
_u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func
(
_u
*
UserUpdate
)
RemoveUsageLogs
(
v
...*
UsageLog
)
*
UserUpdate
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
RemoveUsageLogIDs
(
ids
...
)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func
(
_u
*
UserUpdate
)
Save
(
ctx
context
.
Context
)
(
int
,
error
)
{
if
err
:=
_u
.
defaults
();
err
!=
nil
{
...
...
@@ -751,6 +788,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
edge
.
Target
.
Fields
=
specE
.
Fields
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
if
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
user
.
UsageLogsTable
,
Columns
:
[]
string
{
user
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
RemovedUsageLogsIDs
();
len
(
nodes
)
>
0
&&
!
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
user
.
UsageLogsTable
,
Columns
:
[]
string
{
user
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
UsageLogsIDs
();
len
(
nodes
)
>
0
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
user
.
UsageLogsTable
,
Columns
:
[]
string
{
user
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
if
_node
,
err
=
sqlgraph
.
UpdateNodes
(
ctx
,
_u
.
driver
,
_spec
);
err
!=
nil
{
if
_
,
ok
:=
err
.
(
*
sqlgraph
.
NotFoundError
);
ok
{
err
=
&
NotFoundError
{
user
.
Label
}
...
...
@@ -1012,6 +1094,21 @@ func (_u *UserUpdateOne) AddAllowedGroups(v ...*Group) *UserUpdateOne {
return
_u
.
AddAllowedGroupIDs
(
ids
...
)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func
(
_u
*
UserUpdateOne
)
AddUsageLogIDs
(
ids
...
int64
)
*
UserUpdateOne
{
_u
.
mutation
.
AddUsageLogIDs
(
ids
...
)
return
_u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserUpdateOne
)
AddUsageLogs
(
v
...*
UsageLog
)
*
UserUpdateOne
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
AddUsageLogIDs
(
ids
...
)
}
// Mutation returns the UserMutation object of the builder.
func
(
_u
*
UserUpdateOne
)
Mutation
()
*
UserMutation
{
return
_u
.
mutation
...
...
@@ -1122,6 +1219,27 @@ func (_u *UserUpdateOne) RemoveAllowedGroups(v ...*Group) *UserUpdateOne {
return
_u
.
RemoveAllowedGroupIDs
(
ids
...
)
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserUpdateOne
)
ClearUsageLogs
()
*
UserUpdateOne
{
_u
.
mutation
.
ClearUsageLogs
()
return
_u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func
(
_u
*
UserUpdateOne
)
RemoveUsageLogIDs
(
ids
...
int64
)
*
UserUpdateOne
{
_u
.
mutation
.
RemoveUsageLogIDs
(
ids
...
)
return
_u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func
(
_u
*
UserUpdateOne
)
RemoveUsageLogs
(
v
...*
UsageLog
)
*
UserUpdateOne
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
RemoveUsageLogIDs
(
ids
...
)
}
// Where appends a list predicates to the UserUpdate builder.
func
(
_u
*
UserUpdateOne
)
Where
(
ps
...
predicate
.
User
)
*
UserUpdateOne
{
_u
.
mutation
.
Where
(
ps
...
)
...
...
@@ -1520,6 +1638,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
edge
.
Target
.
Fields
=
specE
.
Fields
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
if
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
user
.
UsageLogsTable
,
Columns
:
[]
string
{
user
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
RemovedUsageLogsIDs
();
len
(
nodes
)
>
0
&&
!
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
user
.
UsageLogsTable
,
Columns
:
[]
string
{
user
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
UsageLogsIDs
();
len
(
nodes
)
>
0
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
user
.
UsageLogsTable
,
Columns
:
[]
string
{
user
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
_node
=
&
User
{
config
:
_u
.
config
}
_spec
.
Assign
=
_node
.
assignValues
_spec
.
ScanValues
=
_node
.
scanValues
...
...
backend/ent/usersubscription.go
View file @
7331220e
...
...
@@ -23,6 +23,8 @@ type UserSubscription struct {
CreatedAt
time
.
Time
`json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt
time
.
Time
`json:"updated_at,omitempty"`
// DeletedAt holds the value of the "deleted_at" field.
DeletedAt
*
time
.
Time
`json:"deleted_at,omitempty"`
// UserID holds the value of the "user_id" field.
UserID
int64
`json:"user_id,omitempty"`
// GroupID holds the value of the "group_id" field.
...
...
@@ -65,9 +67,11 @@ type UserSubscriptionEdges struct {
Group
*
Group
`json:"group,omitempty"`
// AssignedByUser holds the value of the assigned_by_user edge.
AssignedByUser
*
User
`json:"assigned_by_user,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs
[]
*
UsageLog
`json:"usage_logs,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes
[
3
]
bool
loadedTypes
[
4
]
bool
}
// UserOrErr returns the User value or an error if the edge
...
...
@@ -103,6 +107,15 @@ func (e UserSubscriptionEdges) AssignedByUserOrErr() (*User, error) {
return
nil
,
&
NotLoadedError
{
edge
:
"assigned_by_user"
}
}
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func
(
e
UserSubscriptionEdges
)
UsageLogsOrErr
()
([]
*
UsageLog
,
error
)
{
if
e
.
loadedTypes
[
3
]
{
return
e
.
UsageLogs
,
nil
}
return
nil
,
&
NotLoadedError
{
edge
:
"usage_logs"
}
}
// scanValues returns the types for scanning values from sql.Rows.
func
(
*
UserSubscription
)
scanValues
(
columns
[]
string
)
([]
any
,
error
)
{
values
:=
make
([]
any
,
len
(
columns
))
...
...
@@ -114,7 +127,7 @@ func (*UserSubscription) scanValues(columns []string) ([]any, error) {
values
[
i
]
=
new
(
sql
.
NullInt64
)
case
usersubscription
.
FieldStatus
,
usersubscription
.
FieldNotes
:
values
[
i
]
=
new
(
sql
.
NullString
)
case
usersubscription
.
FieldCreatedAt
,
usersubscription
.
FieldUpdatedAt
,
usersubscription
.
FieldStartsAt
,
usersubscription
.
FieldExpiresAt
,
usersubscription
.
FieldDailyWindowStart
,
usersubscription
.
FieldWeeklyWindowStart
,
usersubscription
.
FieldMonthlyWindowStart
,
usersubscription
.
FieldAssignedAt
:
case
usersubscription
.
FieldCreatedAt
,
usersubscription
.
FieldUpdatedAt
,
usersubscription
.
FieldDeletedAt
,
usersubscription
.
FieldStartsAt
,
usersubscription
.
FieldExpiresAt
,
usersubscription
.
FieldDailyWindowStart
,
usersubscription
.
FieldWeeklyWindowStart
,
usersubscription
.
FieldMonthlyWindowStart
,
usersubscription
.
FieldAssignedAt
:
values
[
i
]
=
new
(
sql
.
NullTime
)
default
:
values
[
i
]
=
new
(
sql
.
UnknownType
)
...
...
@@ -149,6 +162,13 @@ func (_m *UserSubscription) assignValues(columns []string, values []any) error {
}
else
if
value
.
Valid
{
_m
.
UpdatedAt
=
value
.
Time
}
case
usersubscription
.
FieldDeletedAt
:
if
value
,
ok
:=
values
[
i
]
.
(
*
sql
.
NullTime
);
!
ok
{
return
fmt
.
Errorf
(
"unexpected type %T for field deleted_at"
,
values
[
i
])
}
else
if
value
.
Valid
{
_m
.
DeletedAt
=
new
(
time
.
Time
)
*
_m
.
DeletedAt
=
value
.
Time
}
case
usersubscription
.
FieldUserID
:
if
value
,
ok
:=
values
[
i
]
.
(
*
sql
.
NullInt64
);
!
ok
{
return
fmt
.
Errorf
(
"unexpected type %T for field user_id"
,
values
[
i
])
...
...
@@ -266,6 +286,11 @@ func (_m *UserSubscription) QueryAssignedByUser() *UserQuery {
return
NewUserSubscriptionClient
(
_m
.
config
)
.
QueryAssignedByUser
(
_m
)
}
// QueryUsageLogs queries the "usage_logs" edge of the UserSubscription entity.
func
(
_m
*
UserSubscription
)
QueryUsageLogs
()
*
UsageLogQuery
{
return
NewUserSubscriptionClient
(
_m
.
config
)
.
QueryUsageLogs
(
_m
)
}
// Update returns a builder for updating this UserSubscription.
// Note that you need to call UserSubscription.Unwrap() before calling this method if this UserSubscription
// was returned from a transaction, and the transaction was committed or rolled back.
...
...
@@ -295,6 +320,11 @@ func (_m *UserSubscription) String() string {
builder
.
WriteString
(
"updated_at="
)
builder
.
WriteString
(
_m
.
UpdatedAt
.
Format
(
time
.
ANSIC
))
builder
.
WriteString
(
", "
)
if
v
:=
_m
.
DeletedAt
;
v
!=
nil
{
builder
.
WriteString
(
"deleted_at="
)
builder
.
WriteString
(
v
.
Format
(
time
.
ANSIC
))
}
builder
.
WriteString
(
", "
)
builder
.
WriteString
(
"user_id="
)
builder
.
WriteString
(
fmt
.
Sprintf
(
"%v"
,
_m
.
UserID
))
builder
.
WriteString
(
", "
)
...
...
backend/ent/usersubscription/usersubscription.go
View file @
7331220e
...
...
@@ -5,6 +5,7 @@ package usersubscription
import
(
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
...
...
@@ -18,6 +19,8 @@ const (
FieldCreatedAt
=
"created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt
=
"updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt
=
"deleted_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID
=
"user_id"
// FieldGroupID holds the string denoting the group_id field in the database.
...
...
@@ -52,6 +55,8 @@ const (
EdgeGroup
=
"group"
// EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations.
EdgeAssignedByUser
=
"assigned_by_user"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs
=
"usage_logs"
// Table holds the table name of the usersubscription in the database.
Table
=
"user_subscriptions"
// UserTable is the table that holds the user relation/edge.
...
...
@@ -75,6 +80,13 @@ const (
AssignedByUserInverseTable
=
"users"
// AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge.
AssignedByUserColumn
=
"assigned_by"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable
=
"usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable
=
"usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn
=
"subscription_id"
)
// Columns holds all SQL columns for usersubscription fields.
...
...
@@ -82,6 +94,7 @@ var Columns = []string{
FieldID
,
FieldCreatedAt
,
FieldUpdatedAt
,
FieldDeletedAt
,
FieldUserID
,
FieldGroupID
,
FieldStartsAt
,
...
...
@@ -108,7 +121,14 @@ func ValidColumn(column string) bool {
return
false
}
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var
(
Hooks
[
1
]
ent
.
Hook
Interceptors
[
1
]
ent
.
Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt
func
()
time
.
Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
...
...
@@ -147,6 +167,11 @@ func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return
sql
.
OrderByField
(
FieldUpdatedAt
,
opts
...
)
.
ToFunc
()
}
// ByDeletedAt orders the results by the deleted_at field.
func
ByDeletedAt
(
opts
...
sql
.
OrderTermOption
)
OrderOption
{
return
sql
.
OrderByField
(
FieldDeletedAt
,
opts
...
)
.
ToFunc
()
}
// ByUserID orders the results by the user_id field.
func
ByUserID
(
opts
...
sql
.
OrderTermOption
)
OrderOption
{
return
sql
.
OrderByField
(
FieldUserID
,
opts
...
)
.
ToFunc
()
...
...
@@ -237,6 +262,20 @@ func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOptio
sqlgraph
.
OrderByNeighborTerms
(
s
,
newAssignedByUserStep
(),
sql
.
OrderByField
(
field
,
opts
...
))
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func
ByUsageLogsCount
(
opts
...
sql
.
OrderTermOption
)
OrderOption
{
return
func
(
s
*
sql
.
Selector
)
{
sqlgraph
.
OrderByNeighborsCount
(
s
,
newUsageLogsStep
(),
opts
...
)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func
ByUsageLogs
(
term
sql
.
OrderTerm
,
terms
...
sql
.
OrderTerm
)
OrderOption
{
return
func
(
s
*
sql
.
Selector
)
{
sqlgraph
.
OrderByNeighborTerms
(
s
,
newUsageLogsStep
(),
append
([]
sql
.
OrderTerm
{
term
},
terms
...
)
...
)
}
}
func
newUserStep
()
*
sqlgraph
.
Step
{
return
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
Table
,
FieldID
),
...
...
@@ -258,3 +297,10 @@ func newAssignedByUserStep() *sqlgraph.Step {
sqlgraph
.
Edge
(
sqlgraph
.
M2O
,
true
,
AssignedByUserTable
,
AssignedByUserColumn
),
)
}
func
newUsageLogsStep
()
*
sqlgraph
.
Step
{
return
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
Table
,
FieldID
),
sqlgraph
.
To
(
UsageLogsInverseTable
,
FieldID
),
sqlgraph
.
Edge
(
sqlgraph
.
O2M
,
false
,
UsageLogsTable
,
UsageLogsColumn
),
)
}
backend/ent/usersubscription/where.go
View file @
7331220e
...
...
@@ -65,6 +65,11 @@ func UpdatedAt(v time.Time) predicate.UserSubscription {
return
predicate
.
UserSubscription
(
sql
.
FieldEQ
(
FieldUpdatedAt
,
v
))
}
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func
DeletedAt
(
v
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldEQ
(
FieldDeletedAt
,
v
))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func
UserID
(
v
int64
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldEQ
(
FieldUserID
,
v
))
...
...
@@ -215,6 +220,56 @@ func UpdatedAtLTE(v time.Time) predicate.UserSubscription {
return
predicate
.
UserSubscription
(
sql
.
FieldLTE
(
FieldUpdatedAt
,
v
))
}
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func
DeletedAtEQ
(
v
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldEQ
(
FieldDeletedAt
,
v
))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func
DeletedAtNEQ
(
v
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldNEQ
(
FieldDeletedAt
,
v
))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func
DeletedAtIn
(
vs
...
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldIn
(
FieldDeletedAt
,
vs
...
))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func
DeletedAtNotIn
(
vs
...
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldNotIn
(
FieldDeletedAt
,
vs
...
))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func
DeletedAtGT
(
v
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldGT
(
FieldDeletedAt
,
v
))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func
DeletedAtGTE
(
v
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldGTE
(
FieldDeletedAt
,
v
))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func
DeletedAtLT
(
v
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldLT
(
FieldDeletedAt
,
v
))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func
DeletedAtLTE
(
v
time
.
Time
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldLTE
(
FieldDeletedAt
,
v
))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func
DeletedAtIsNil
()
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldIsNull
(
FieldDeletedAt
))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func
DeletedAtNotNil
()
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldNotNull
(
FieldDeletedAt
))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func
UserIDEQ
(
v
int64
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
FieldEQ
(
FieldUserID
,
v
))
...
...
@@ -884,6 +939,29 @@ func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription {
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func
HasUsageLogs
()
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
func
(
s
*
sql
.
Selector
)
{
step
:=
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
Table
,
FieldID
),
sqlgraph
.
Edge
(
sqlgraph
.
O2M
,
false
,
UsageLogsTable
,
UsageLogsColumn
),
)
sqlgraph
.
HasNeighbors
(
s
,
step
)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func
HasUsageLogsWith
(
preds
...
predicate
.
UsageLog
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
func
(
s
*
sql
.
Selector
)
{
step
:=
newUsageLogsStep
()
sqlgraph
.
HasNeighborsWith
(
s
,
step
,
func
(
s
*
sql
.
Selector
)
{
for
_
,
p
:=
range
preds
{
p
(
s
)
}
})
})
}
// And groups predicates with the AND operator between them.
func
And
(
predicates
...
predicate
.
UserSubscription
)
predicate
.
UserSubscription
{
return
predicate
.
UserSubscription
(
sql
.
AndPredicates
(
predicates
...
))
...
...
backend/ent/usersubscription_create.go
View file @
7331220e
...
...
@@ -12,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
...
...
@@ -52,6 +53,20 @@ func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscr
return
_c
}
// SetDeletedAt sets the "deleted_at" field.
func
(
_c
*
UserSubscriptionCreate
)
SetDeletedAt
(
v
time
.
Time
)
*
UserSubscriptionCreate
{
_c
.
mutation
.
SetDeletedAt
(
v
)
return
_c
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func
(
_c
*
UserSubscriptionCreate
)
SetNillableDeletedAt
(
v
*
time
.
Time
)
*
UserSubscriptionCreate
{
if
v
!=
nil
{
_c
.
SetDeletedAt
(
*
v
)
}
return
_c
}
// SetUserID sets the "user_id" field.
func
(
_c
*
UserSubscriptionCreate
)
SetUserID
(
v
int64
)
*
UserSubscriptionCreate
{
_c
.
mutation
.
SetUserID
(
v
)
...
...
@@ -245,6 +260,21 @@ func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCr
return
_c
.
SetAssignedByUserID
(
v
.
ID
)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func
(
_c
*
UserSubscriptionCreate
)
AddUsageLogIDs
(
ids
...
int64
)
*
UserSubscriptionCreate
{
_c
.
mutation
.
AddUsageLogIDs
(
ids
...
)
return
_c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func
(
_c
*
UserSubscriptionCreate
)
AddUsageLogs
(
v
...*
UsageLog
)
*
UserSubscriptionCreate
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_c
.
AddUsageLogIDs
(
ids
...
)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func
(
_c
*
UserSubscriptionCreate
)
Mutation
()
*
UserSubscriptionMutation
{
return
_c
.
mutation
...
...
@@ -252,7 +282,9 @@ func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
// Save creates the UserSubscription in the database.
func
(
_c
*
UserSubscriptionCreate
)
Save
(
ctx
context
.
Context
)
(
*
UserSubscription
,
error
)
{
_c
.
defaults
()
if
err
:=
_c
.
defaults
();
err
!=
nil
{
return
nil
,
err
}
return
withHooks
(
ctx
,
_c
.
sqlSave
,
_c
.
mutation
,
_c
.
hooks
)
}
...
...
@@ -279,12 +311,18 @@ func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func
(
_c
*
UserSubscriptionCreate
)
defaults
()
{
func
(
_c
*
UserSubscriptionCreate
)
defaults
()
error
{
if
_
,
ok
:=
_c
.
mutation
.
CreatedAt
();
!
ok
{
if
usersubscription
.
DefaultCreatedAt
==
nil
{
return
fmt
.
Errorf
(
"ent: uninitialized usersubscription.DefaultCreatedAt (forgotten import ent/runtime?)"
)
}
v
:=
usersubscription
.
DefaultCreatedAt
()
_c
.
mutation
.
SetCreatedAt
(
v
)
}
if
_
,
ok
:=
_c
.
mutation
.
UpdatedAt
();
!
ok
{
if
usersubscription
.
DefaultUpdatedAt
==
nil
{
return
fmt
.
Errorf
(
"ent: uninitialized usersubscription.DefaultUpdatedAt (forgotten import ent/runtime?)"
)
}
v
:=
usersubscription
.
DefaultUpdatedAt
()
_c
.
mutation
.
SetUpdatedAt
(
v
)
}
...
...
@@ -305,9 +343,13 @@ func (_c *UserSubscriptionCreate) defaults() {
_c
.
mutation
.
SetMonthlyUsageUsd
(
v
)
}
if
_
,
ok
:=
_c
.
mutation
.
AssignedAt
();
!
ok
{
if
usersubscription
.
DefaultAssignedAt
==
nil
{
return
fmt
.
Errorf
(
"ent: uninitialized usersubscription.DefaultAssignedAt (forgotten import ent/runtime?)"
)
}
v
:=
usersubscription
.
DefaultAssignedAt
()
_c
.
mutation
.
SetAssignedAt
(
v
)
}
return
nil
}
// check runs all checks and user-defined validators on the builder.
...
...
@@ -391,6 +433,10 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_spec
.
SetField
(
usersubscription
.
FieldUpdatedAt
,
field
.
TypeTime
,
value
)
_node
.
UpdatedAt
=
value
}
if
value
,
ok
:=
_c
.
mutation
.
DeletedAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldDeletedAt
,
field
.
TypeTime
,
value
)
_node
.
DeletedAt
=
&
value
}
if
value
,
ok
:=
_c
.
mutation
.
StartsAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldStartsAt
,
field
.
TypeTime
,
value
)
_node
.
StartsAt
=
value
...
...
@@ -486,6 +532,22 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_node
.
AssignedBy
=
&
nodes
[
0
]
_spec
.
Edges
=
append
(
_spec
.
Edges
,
edge
)
}
if
nodes
:=
_c
.
mutation
.
UsageLogsIDs
();
len
(
nodes
)
>
0
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
usersubscription
.
UsageLogsTable
,
Columns
:
[]
string
{
usersubscription
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
=
append
(
_spec
.
Edges
,
edge
)
}
return
_node
,
_spec
}
...
...
@@ -550,6 +612,24 @@ func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert {
return
u
}
// SetDeletedAt sets the "deleted_at" field.
func
(
u
*
UserSubscriptionUpsert
)
SetDeletedAt
(
v
time
.
Time
)
*
UserSubscriptionUpsert
{
u
.
Set
(
usersubscription
.
FieldDeletedAt
,
v
)
return
u
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func
(
u
*
UserSubscriptionUpsert
)
UpdateDeletedAt
()
*
UserSubscriptionUpsert
{
u
.
SetExcluded
(
usersubscription
.
FieldDeletedAt
)
return
u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func
(
u
*
UserSubscriptionUpsert
)
ClearDeletedAt
()
*
UserSubscriptionUpsert
{
u
.
SetNull
(
usersubscription
.
FieldDeletedAt
)
return
u
}
// SetUserID sets the "user_id" field.
func
(
u
*
UserSubscriptionUpsert
)
SetUserID
(
v
int64
)
*
UserSubscriptionUpsert
{
u
.
Set
(
usersubscription
.
FieldUserID
,
v
)
...
...
@@ -825,6 +905,27 @@ func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne
})
}
// SetDeletedAt sets the "deleted_at" field.
func
(
u
*
UserSubscriptionUpsertOne
)
SetDeletedAt
(
v
time
.
Time
)
*
UserSubscriptionUpsertOne
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
s
.
SetDeletedAt
(
v
)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func
(
u
*
UserSubscriptionUpsertOne
)
UpdateDeletedAt
()
*
UserSubscriptionUpsertOne
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
s
.
UpdateDeletedAt
()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func
(
u
*
UserSubscriptionUpsertOne
)
ClearDeletedAt
()
*
UserSubscriptionUpsertOne
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
s
.
ClearDeletedAt
()
})
}
// SetUserID sets the "user_id" field.
func
(
u
*
UserSubscriptionUpsertOne
)
SetUserID
(
v
int64
)
*
UserSubscriptionUpsertOne
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
...
...
@@ -1302,6 +1403,27 @@ func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBu
})
}
// SetDeletedAt sets the "deleted_at" field.
func
(
u
*
UserSubscriptionUpsertBulk
)
SetDeletedAt
(
v
time
.
Time
)
*
UserSubscriptionUpsertBulk
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
s
.
SetDeletedAt
(
v
)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func
(
u
*
UserSubscriptionUpsertBulk
)
UpdateDeletedAt
()
*
UserSubscriptionUpsertBulk
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
s
.
UpdateDeletedAt
()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func
(
u
*
UserSubscriptionUpsertBulk
)
ClearDeletedAt
()
*
UserSubscriptionUpsertBulk
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
s
.
ClearDeletedAt
()
})
}
// SetUserID sets the "user_id" field.
func
(
u
*
UserSubscriptionUpsertBulk
)
SetUserID
(
v
int64
)
*
UserSubscriptionUpsertBulk
{
return
u
.
Update
(
func
(
s
*
UserSubscriptionUpsert
)
{
...
...
backend/ent/usersubscription_query.go
View file @
7331220e
...
...
@@ -4,6 +4,7 @@ package ent
import
(
"context"
"database/sql/driver"
"fmt"
"math"
...
...
@@ -13,6 +14,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
...
...
@@ -27,6 +29,7 @@ type UserSubscriptionQuery struct {
withUser
*
UserQuery
withGroup
*
GroupQuery
withAssignedByUser
*
UserQuery
withUsageLogs
*
UsageLogQuery
// intermediate query (i.e. traversal path).
sql
*
sql
.
Selector
path
func
(
context
.
Context
)
(
*
sql
.
Selector
,
error
)
...
...
@@ -129,6 +132,28 @@ func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery {
return
query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func
(
_q
*
UserSubscriptionQuery
)
QueryUsageLogs
()
*
UsageLogQuery
{
query
:=
(
&
UsageLogClient
{
config
:
_q
.
config
})
.
Query
()
query
.
path
=
func
(
ctx
context
.
Context
)
(
fromU
*
sql
.
Selector
,
err
error
)
{
if
err
:=
_q
.
prepareQuery
(
ctx
);
err
!=
nil
{
return
nil
,
err
}
selector
:=
_q
.
sqlQuery
(
ctx
)
if
err
:=
selector
.
Err
();
err
!=
nil
{
return
nil
,
err
}
step
:=
sqlgraph
.
NewStep
(
sqlgraph
.
From
(
usersubscription
.
Table
,
usersubscription
.
FieldID
,
selector
),
sqlgraph
.
To
(
usagelog
.
Table
,
usagelog
.
FieldID
),
sqlgraph
.
Edge
(
sqlgraph
.
O2M
,
false
,
usersubscription
.
UsageLogsTable
,
usersubscription
.
UsageLogsColumn
),
)
fromU
=
sqlgraph
.
SetNeighbors
(
_q
.
driver
.
Dialect
(),
step
)
return
fromU
,
nil
}
return
query
}
// First returns the first UserSubscription entity from the query.
// Returns a *NotFoundError when no UserSubscription was found.
func
(
_q
*
UserSubscriptionQuery
)
First
(
ctx
context
.
Context
)
(
*
UserSubscription
,
error
)
{
...
...
@@ -324,6 +349,7 @@ func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery {
withUser
:
_q
.
withUser
.
Clone
(),
withGroup
:
_q
.
withGroup
.
Clone
(),
withAssignedByUser
:
_q
.
withAssignedByUser
.
Clone
(),
withUsageLogs
:
_q
.
withUsageLogs
.
Clone
(),
// clone intermediate query.
sql
:
_q
.
sql
.
Clone
(),
path
:
_q
.
path
,
...
...
@@ -363,6 +389,17 @@ func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *U
return
_q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func
(
_q
*
UserSubscriptionQuery
)
WithUsageLogs
(
opts
...
func
(
*
UsageLogQuery
))
*
UserSubscriptionQuery
{
query
:=
(
&
UsageLogClient
{
config
:
_q
.
config
})
.
Query
()
for
_
,
opt
:=
range
opts
{
opt
(
query
)
}
_q
.
withUsageLogs
=
query
return
_q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
...
...
@@ -441,10 +478,11 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
var
(
nodes
=
[]
*
UserSubscription
{}
_spec
=
_q
.
querySpec
()
loadedTypes
=
[
3
]
bool
{
loadedTypes
=
[
4
]
bool
{
_q
.
withUser
!=
nil
,
_q
.
withGroup
!=
nil
,
_q
.
withAssignedByUser
!=
nil
,
_q
.
withUsageLogs
!=
nil
,
}
)
_spec
.
ScanValues
=
func
(
columns
[]
string
)
([]
any
,
error
)
{
...
...
@@ -483,6 +521,13 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
return
nil
,
err
}
}
if
query
:=
_q
.
withUsageLogs
;
query
!=
nil
{
if
err
:=
_q
.
loadUsageLogs
(
ctx
,
query
,
nodes
,
func
(
n
*
UserSubscription
)
{
n
.
Edges
.
UsageLogs
=
[]
*
UsageLog
{}
},
func
(
n
*
UserSubscription
,
e
*
UsageLog
)
{
n
.
Edges
.
UsageLogs
=
append
(
n
.
Edges
.
UsageLogs
,
e
)
});
err
!=
nil
{
return
nil
,
err
}
}
return
nodes
,
nil
}
...
...
@@ -576,6 +621,39 @@ func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query *
}
return
nil
}
func
(
_q
*
UserSubscriptionQuery
)
loadUsageLogs
(
ctx
context
.
Context
,
query
*
UsageLogQuery
,
nodes
[]
*
UserSubscription
,
init
func
(
*
UserSubscription
),
assign
func
(
*
UserSubscription
,
*
UsageLog
))
error
{
fks
:=
make
([]
driver
.
Value
,
0
,
len
(
nodes
))
nodeids
:=
make
(
map
[
int64
]
*
UserSubscription
)
for
i
:=
range
nodes
{
fks
=
append
(
fks
,
nodes
[
i
]
.
ID
)
nodeids
[
nodes
[
i
]
.
ID
]
=
nodes
[
i
]
if
init
!=
nil
{
init
(
nodes
[
i
])
}
}
if
len
(
query
.
ctx
.
Fields
)
>
0
{
query
.
ctx
.
AppendFieldOnce
(
usagelog
.
FieldSubscriptionID
)
}
query
.
Where
(
predicate
.
UsageLog
(
func
(
s
*
sql
.
Selector
)
{
s
.
Where
(
sql
.
InValues
(
s
.
C
(
usersubscription
.
UsageLogsColumn
),
fks
...
))
}))
neighbors
,
err
:=
query
.
All
(
ctx
)
if
err
!=
nil
{
return
err
}
for
_
,
n
:=
range
neighbors
{
fk
:=
n
.
SubscriptionID
if
fk
==
nil
{
return
fmt
.
Errorf
(
`foreign-key "subscription_id" is nil for node %v`
,
n
.
ID
)
}
node
,
ok
:=
nodeids
[
*
fk
]
if
!
ok
{
return
fmt
.
Errorf
(
`unexpected referenced foreign-key "subscription_id" returned %v for node %v`
,
*
fk
,
n
.
ID
)
}
assign
(
node
,
n
)
}
return
nil
}
func
(
_q
*
UserSubscriptionQuery
)
sqlCount
(
ctx
context
.
Context
)
(
int
,
error
)
{
_spec
:=
_q
.
querySpec
()
...
...
backend/ent/usersubscription_update.go
View file @
7331220e
...
...
@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
...
...
@@ -36,6 +37,26 @@ func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpd
return
_u
}
// SetDeletedAt sets the "deleted_at" field.
func
(
_u
*
UserSubscriptionUpdate
)
SetDeletedAt
(
v
time
.
Time
)
*
UserSubscriptionUpdate
{
_u
.
mutation
.
SetDeletedAt
(
v
)
return
_u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func
(
_u
*
UserSubscriptionUpdate
)
SetNillableDeletedAt
(
v
*
time
.
Time
)
*
UserSubscriptionUpdate
{
if
v
!=
nil
{
_u
.
SetDeletedAt
(
*
v
)
}
return
_u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func
(
_u
*
UserSubscriptionUpdate
)
ClearDeletedAt
()
*
UserSubscriptionUpdate
{
_u
.
mutation
.
ClearDeletedAt
()
return
_u
}
// SetUserID sets the "user_id" field.
func
(
_u
*
UserSubscriptionUpdate
)
SetUserID
(
v
int64
)
*
UserSubscriptionUpdate
{
_u
.
mutation
.
SetUserID
(
v
)
...
...
@@ -312,6 +333,21 @@ func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUp
return
_u
.
SetAssignedByUserID
(
v
.
ID
)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func
(
_u
*
UserSubscriptionUpdate
)
AddUsageLogIDs
(
ids
...
int64
)
*
UserSubscriptionUpdate
{
_u
.
mutation
.
AddUsageLogIDs
(
ids
...
)
return
_u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserSubscriptionUpdate
)
AddUsageLogs
(
v
...*
UsageLog
)
*
UserSubscriptionUpdate
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
AddUsageLogIDs
(
ids
...
)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func
(
_u
*
UserSubscriptionUpdate
)
Mutation
()
*
UserSubscriptionMutation
{
return
_u
.
mutation
...
...
@@ -335,9 +371,32 @@ func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate
return
_u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserSubscriptionUpdate
)
ClearUsageLogs
()
*
UserSubscriptionUpdate
{
_u
.
mutation
.
ClearUsageLogs
()
return
_u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func
(
_u
*
UserSubscriptionUpdate
)
RemoveUsageLogIDs
(
ids
...
int64
)
*
UserSubscriptionUpdate
{
_u
.
mutation
.
RemoveUsageLogIDs
(
ids
...
)
return
_u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func
(
_u
*
UserSubscriptionUpdate
)
RemoveUsageLogs
(
v
...*
UsageLog
)
*
UserSubscriptionUpdate
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
RemoveUsageLogIDs
(
ids
...
)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func
(
_u
*
UserSubscriptionUpdate
)
Save
(
ctx
context
.
Context
)
(
int
,
error
)
{
_u
.
defaults
()
if
err
:=
_u
.
defaults
();
err
!=
nil
{
return
0
,
err
}
return
withHooks
(
ctx
,
_u
.
sqlSave
,
_u
.
mutation
,
_u
.
hooks
)
}
...
...
@@ -364,11 +423,15 @@ func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func
(
_u
*
UserSubscriptionUpdate
)
defaults
()
{
func
(
_u
*
UserSubscriptionUpdate
)
defaults
()
error
{
if
_
,
ok
:=
_u
.
mutation
.
UpdatedAt
();
!
ok
{
if
usersubscription
.
UpdateDefaultUpdatedAt
==
nil
{
return
fmt
.
Errorf
(
"ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)"
)
}
v
:=
usersubscription
.
UpdateDefaultUpdatedAt
()
_u
.
mutation
.
SetUpdatedAt
(
v
)
}
return
nil
}
// check runs all checks and user-defined validators on the builder.
...
...
@@ -402,6 +465,12 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
if
value
,
ok
:=
_u
.
mutation
.
UpdatedAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldUpdatedAt
,
field
.
TypeTime
,
value
)
}
if
value
,
ok
:=
_u
.
mutation
.
DeletedAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldDeletedAt
,
field
.
TypeTime
,
value
)
}
if
_u
.
mutation
.
DeletedAtCleared
()
{
_spec
.
ClearField
(
usersubscription
.
FieldDeletedAt
,
field
.
TypeTime
)
}
if
value
,
ok
:=
_u
.
mutation
.
StartsAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldStartsAt
,
field
.
TypeTime
,
value
)
}
...
...
@@ -543,6 +612,51 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
}
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
if
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
usersubscription
.
UsageLogsTable
,
Columns
:
[]
string
{
usersubscription
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
RemovedUsageLogsIDs
();
len
(
nodes
)
>
0
&&
!
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
usersubscription
.
UsageLogsTable
,
Columns
:
[]
string
{
usersubscription
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
UsageLogsIDs
();
len
(
nodes
)
>
0
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
usersubscription
.
UsageLogsTable
,
Columns
:
[]
string
{
usersubscription
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
if
_node
,
err
=
sqlgraph
.
UpdateNodes
(
ctx
,
_u
.
driver
,
_spec
);
err
!=
nil
{
if
_
,
ok
:=
err
.
(
*
sqlgraph
.
NotFoundError
);
ok
{
err
=
&
NotFoundError
{
usersubscription
.
Label
}
...
...
@@ -569,6 +683,26 @@ func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscription
return
_u
}
// SetDeletedAt sets the "deleted_at" field.
func
(
_u
*
UserSubscriptionUpdateOne
)
SetDeletedAt
(
v
time
.
Time
)
*
UserSubscriptionUpdateOne
{
_u
.
mutation
.
SetDeletedAt
(
v
)
return
_u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func
(
_u
*
UserSubscriptionUpdateOne
)
SetNillableDeletedAt
(
v
*
time
.
Time
)
*
UserSubscriptionUpdateOne
{
if
v
!=
nil
{
_u
.
SetDeletedAt
(
*
v
)
}
return
_u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func
(
_u
*
UserSubscriptionUpdateOne
)
ClearDeletedAt
()
*
UserSubscriptionUpdateOne
{
_u
.
mutation
.
ClearDeletedAt
()
return
_u
}
// SetUserID sets the "user_id" field.
func
(
_u
*
UserSubscriptionUpdateOne
)
SetUserID
(
v
int64
)
*
UserSubscriptionUpdateOne
{
_u
.
mutation
.
SetUserID
(
v
)
...
...
@@ -845,6 +979,21 @@ func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptio
return
_u
.
SetAssignedByUserID
(
v
.
ID
)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func
(
_u
*
UserSubscriptionUpdateOne
)
AddUsageLogIDs
(
ids
...
int64
)
*
UserSubscriptionUpdateOne
{
_u
.
mutation
.
AddUsageLogIDs
(
ids
...
)
return
_u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserSubscriptionUpdateOne
)
AddUsageLogs
(
v
...*
UsageLog
)
*
UserSubscriptionUpdateOne
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
AddUsageLogIDs
(
ids
...
)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func
(
_u
*
UserSubscriptionUpdateOne
)
Mutation
()
*
UserSubscriptionMutation
{
return
_u
.
mutation
...
...
@@ -868,6 +1017,27 @@ func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpda
return
_u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func
(
_u
*
UserSubscriptionUpdateOne
)
ClearUsageLogs
()
*
UserSubscriptionUpdateOne
{
_u
.
mutation
.
ClearUsageLogs
()
return
_u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func
(
_u
*
UserSubscriptionUpdateOne
)
RemoveUsageLogIDs
(
ids
...
int64
)
*
UserSubscriptionUpdateOne
{
_u
.
mutation
.
RemoveUsageLogIDs
(
ids
...
)
return
_u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func
(
_u
*
UserSubscriptionUpdateOne
)
RemoveUsageLogs
(
v
...*
UsageLog
)
*
UserSubscriptionUpdateOne
{
ids
:=
make
([]
int64
,
len
(
v
))
for
i
:=
range
v
{
ids
[
i
]
=
v
[
i
]
.
ID
}
return
_u
.
RemoveUsageLogIDs
(
ids
...
)
}
// Where appends a list predicates to the UserSubscriptionUpdate builder.
func
(
_u
*
UserSubscriptionUpdateOne
)
Where
(
ps
...
predicate
.
UserSubscription
)
*
UserSubscriptionUpdateOne
{
_u
.
mutation
.
Where
(
ps
...
)
...
...
@@ -883,7 +1053,9 @@ func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *Use
// Save executes the query and returns the updated UserSubscription entity.
func
(
_u
*
UserSubscriptionUpdateOne
)
Save
(
ctx
context
.
Context
)
(
*
UserSubscription
,
error
)
{
_u
.
defaults
()
if
err
:=
_u
.
defaults
();
err
!=
nil
{
return
nil
,
err
}
return
withHooks
(
ctx
,
_u
.
sqlSave
,
_u
.
mutation
,
_u
.
hooks
)
}
...
...
@@ -910,11 +1082,15 @@ func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func
(
_u
*
UserSubscriptionUpdateOne
)
defaults
()
{
func
(
_u
*
UserSubscriptionUpdateOne
)
defaults
()
error
{
if
_
,
ok
:=
_u
.
mutation
.
UpdatedAt
();
!
ok
{
if
usersubscription
.
UpdateDefaultUpdatedAt
==
nil
{
return
fmt
.
Errorf
(
"ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)"
)
}
v
:=
usersubscription
.
UpdateDefaultUpdatedAt
()
_u
.
mutation
.
SetUpdatedAt
(
v
)
}
return
nil
}
// check runs all checks and user-defined validators on the builder.
...
...
@@ -965,6 +1141,12 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
if
value
,
ok
:=
_u
.
mutation
.
UpdatedAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldUpdatedAt
,
field
.
TypeTime
,
value
)
}
if
value
,
ok
:=
_u
.
mutation
.
DeletedAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldDeletedAt
,
field
.
TypeTime
,
value
)
}
if
_u
.
mutation
.
DeletedAtCleared
()
{
_spec
.
ClearField
(
usersubscription
.
FieldDeletedAt
,
field
.
TypeTime
)
}
if
value
,
ok
:=
_u
.
mutation
.
StartsAt
();
ok
{
_spec
.
SetField
(
usersubscription
.
FieldStartsAt
,
field
.
TypeTime
,
value
)
}
...
...
@@ -1106,6 +1288,51 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
}
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
if
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
usersubscription
.
UsageLogsTable
,
Columns
:
[]
string
{
usersubscription
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
RemovedUsageLogsIDs
();
len
(
nodes
)
>
0
&&
!
_u
.
mutation
.
UsageLogsCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
usersubscription
.
UsageLogsTable
,
Columns
:
[]
string
{
usersubscription
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Clear
=
append
(
_spec
.
Edges
.
Clear
,
edge
)
}
if
nodes
:=
_u
.
mutation
.
UsageLogsIDs
();
len
(
nodes
)
>
0
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Inverse
:
false
,
Table
:
usersubscription
.
UsageLogsTable
,
Columns
:
[]
string
{
usersubscription
.
UsageLogsColumn
},
Bidi
:
false
,
Target
:
&
sqlgraph
.
EdgeTarget
{
IDSpec
:
sqlgraph
.
NewFieldSpec
(
usagelog
.
FieldID
,
field
.
TypeInt64
),
},
}
for
_
,
k
:=
range
nodes
{
edge
.
Target
.
Nodes
=
append
(
edge
.
Target
.
Nodes
,
k
)
}
_spec
.
Edges
.
Add
=
append
(
_spec
.
Edges
.
Add
,
edge
)
}
_node
=
&
UserSubscription
{
config
:
_u
.
config
}
_spec
.
Assign
=
_node
.
assignValues
_spec
.
ScanValues
=
_node
.
scanValues
...
...
backend/internal/config/config.go
View file @
7331220e
...
...
@@ -3,6 +3,7 @@ package config
import
(
"fmt"
"strings"
"time"
"github.com/spf13/viper"
)
...
...
@@ -12,6 +13,20 @@ const (
RunModeSimple
=
"simple"
)
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const
(
// ConnectionPoolIsolationProxy: 按代理隔离
// 同一代理地址共享连接池,适合代理数量少、账户数量多的场景
ConnectionPoolIsolationProxy
=
"proxy"
// ConnectionPoolIsolationAccount: 按账户隔离
// 每个账户独立连接池,适合账户数量少、需要严格隔离的场景
ConnectionPoolIsolationAccount
=
"account"
// ConnectionPoolIsolationAccountProxy: 按账户+代理组合隔离(默认)
// 同一账户+代理组合共享连接池,提供最细粒度的隔离
ConnectionPoolIsolationAccountProxy
=
"account_proxy"
)
type
Config
struct
{
Server
ServerConfig
`mapstructure:"server"`
Database
DatabaseConfig
`mapstructure:"database"`
...
...
@@ -29,6 +44,7 @@ type Config struct {
type
GeminiConfig
struct
{
OAuth
GeminiOAuthConfig
`mapstructure:"oauth"`
Quota
GeminiQuotaConfig
`mapstructure:"quota"`
}
type
GeminiOAuthConfig
struct
{
...
...
@@ -37,6 +53,17 @@ type GeminiOAuthConfig struct {
Scopes
string
`mapstructure:"scopes"`
}
type
GeminiQuotaConfig
struct
{
Tiers
map
[
string
]
GeminiTierQuotaConfig
`mapstructure:"tiers"`
Policy
string
`mapstructure:"policy"`
}
type
GeminiTierQuotaConfig
struct
{
ProRPD
*
int64
`mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD
*
int64
`mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes
*
int
`mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type
TokenRefreshConfig
struct
{
// 是否启用自动刷新
...
...
@@ -79,12 +106,71 @@ type GatewayConfig struct {
// 等待上游响应头的超时时间(秒),0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout
int
`mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize
int64
`mapstructure:"max_body_size"`
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
ConnectionPoolIsolation
string
`mapstructure:"connection_pool_isolation"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
MaxIdleConns
int
`mapstructure:"max_idle_conns"`
// MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率)
MaxIdleConnsPerHost
int
`mapstructure:"max_idle_conns_per_host"`
// MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲),0表示无限制
MaxConnsPerHost
int
`mapstructure:"max_conns_per_host"`
// IdleConnTimeoutSeconds: 空闲连接超时时间(秒)
IdleConnTimeoutSeconds
int
`mapstructure:"idle_conn_timeout_seconds"`
// MaxUpstreamClients: 上游连接池客户端最大缓存数量
// 当使用连接池隔离策略时,系统会为不同的账户/代理组合创建独立的 HTTP 客户端
// 此参数限制缓存的客户端数量,超出后会淘汰最久未使用的客户端
// 建议值:预估的活跃账户数 * 1.2(留有余量)
MaxUpstreamClients
int
`mapstructure:"max_upstream_clients"`
// ClientIdleTTLSeconds: 上游连接池客户端空闲回收阈值(秒)
// 超过此时间未使用的客户端会被标记为可回收
// 建议值:根据用户访问频率设置,一般 10-30 分钟
ClientIdleTTLSeconds
int
`mapstructure:"client_idle_ttl_seconds"`
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody
bool
`mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes
int
`mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey
bool
`mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400
bool
`mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling
GatewaySchedulingConfig
`mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type GatewaySchedulingConfig struct {
	// Sticky-session queueing: how many requests may wait for a sticky
	// account at once, and how long each may wait before giving up.
	StickySessionMaxWaiting  int           `mapstructure:"sticky_session_max_waiting"`
	StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
	// Fallback queueing configuration (applies when no sticky binding is used).
	FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
	FallbackMaxWaiting  int           `mapstructure:"fallback_max_waiting"`
	// Load calculation: whether batched load computation is enabled.
	LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
	// Cleanup interval for expired concurrency slots (0 disables cleanup).
	SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
}
func
(
s
*
ServerConfig
)
Address
()
string
{
return
fmt
.
Sprintf
(
"%s:%d"
,
s
.
Host
,
s
.
Port
)
}
// DatabaseConfig 数据库连接配置
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
type
DatabaseConfig
struct
{
Host
string
`mapstructure:"host"`
Port
int
`mapstructure:"port"`
...
...
@@ -92,6 +178,15 @@ type DatabaseConfig struct {
Password
string
`mapstructure:"password"`
DBName
string
`mapstructure:"dbname"`
SSLMode
string
`mapstructure:"sslmode"`
// 连接池配置(性能优化:可配置化连接池参数)
// MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽
MaxOpenConns
int
`mapstructure:"max_open_conns"`
// MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟
MaxIdleConns
int
`mapstructure:"max_idle_conns"`
// ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏
ConnMaxLifetimeMinutes
int
`mapstructure:"conn_max_lifetime_minutes"`
// ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接
ConnMaxIdleTimeMinutes
int
`mapstructure:"conn_max_idle_time_minutes"`
}
func
(
d
*
DatabaseConfig
)
DSN
()
string
{
...
...
@@ -112,11 +207,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
)
}
// RedisConfig holds the Redis connection configuration.
// Performance note: pool and timeout parameters are configurable to improve
// throughput under high concurrency.
type RedisConfig struct {
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	Password string `mapstructure:"password"`
	DB       int    `mapstructure:"db"`
	// Pool and timeout settings (performance: configurable pool parameters).
	// DialTimeoutSeconds: connection-establishment timeout; prevents slow
	// dials from blocking.
	DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
	// ReadTimeoutSeconds: read timeout; keeps slow queries from tying up
	// pooled connections.
	ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
	// WriteTimeoutSeconds: write timeout; keeps slow writes from tying up
	// pooled connections.
	WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
	// PoolSize: connection pool size; caps the number of concurrent connections.
	PoolSize int `mapstructure:"pool_size"`
	// MinIdleConns: minimum idle connections kept warm to reduce cold-start latency.
	MinIdleConns int `mapstructure:"min_idle_conns"`
}
func
(
r
*
RedisConfig
)
Address
()
string
{
...
...
@@ -203,12 +311,21 @@ func setDefaults() {
viper
.
SetDefault
(
"database.password"
,
"postgres"
)
viper
.
SetDefault
(
"database.dbname"
,
"sub2api"
)
viper
.
SetDefault
(
"database.sslmode"
,
"disable"
)
viper
.
SetDefault
(
"database.max_open_conns"
,
50
)
viper
.
SetDefault
(
"database.max_idle_conns"
,
10
)
viper
.
SetDefault
(
"database.conn_max_lifetime_minutes"
,
30
)
viper
.
SetDefault
(
"database.conn_max_idle_time_minutes"
,
5
)
// Redis
viper
.
SetDefault
(
"redis.host"
,
"localhost"
)
viper
.
SetDefault
(
"redis.port"
,
6379
)
viper
.
SetDefault
(
"redis.password"
,
""
)
viper
.
SetDefault
(
"redis.db"
,
0
)
viper
.
SetDefault
(
"redis.dial_timeout_seconds"
,
5
)
viper
.
SetDefault
(
"redis.read_timeout_seconds"
,
3
)
viper
.
SetDefault
(
"redis.write_timeout_seconds"
,
3
)
viper
.
SetDefault
(
"redis.pool_size"
,
128
)
viper
.
SetDefault
(
"redis.min_idle_conns"
,
10
)
// JWT
viper
.
SetDefault
(
"jwt.secret"
,
"change-me-in-production"
)
...
...
@@ -240,6 +357,26 @@ func setDefaults() {
// Gateway
viper
.
SetDefault
(
"gateway.response_header_timeout"
,
300
)
// 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
viper
.
SetDefault
(
"gateway.log_upstream_error_body"
,
false
)
viper
.
SetDefault
(
"gateway.log_upstream_error_body_max_bytes"
,
2048
)
viper
.
SetDefault
(
"gateway.inject_beta_for_apikey"
,
false
)
viper
.
SetDefault
(
"gateway.failover_on_400"
,
false
)
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.connection_pool_isolation"
,
ConnectionPoolIsolationAccountProxy
)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper
.
SetDefault
(
"gateway.max_idle_conns"
,
240
)
// 最大空闲连接总数(HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.max_idle_conns_per_host"
,
120
)
// 每主机最大空闲连接(HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.max_conns_per_host"
,
240
)
// 每主机最大连接数(含活跃,HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.idle_conn_timeout_seconds"
,
300
)
// 空闲连接超时(秒)
viper
.
SetDefault
(
"gateway.max_upstream_clients"
,
5000
)
viper
.
SetDefault
(
"gateway.client_idle_ttl_seconds"
,
900
)
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
15
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_max_waiting"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_wait_timeout"
,
45
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_wait_timeout"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_max_waiting"
,
100
)
viper
.
SetDefault
(
"gateway.scheduling.load_batch_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.scheduling.slot_cleanup_interval"
,
30
*
time
.
Second
)
// TokenRefresh
viper
.
SetDefault
(
"token_refresh.enabled"
,
true
)
...
...
@@ -254,6 +391,7 @@ func setDefaults() {
viper
.
SetDefault
(
"gemini.oauth.client_id"
,
""
)
viper
.
SetDefault
(
"gemini.oauth.client_secret"
,
""
)
viper
.
SetDefault
(
"gemini.oauth.scopes"
,
""
)
viper
.
SetDefault
(
"gemini.quota.policy"
,
""
)
}
func
(
c
*
Config
)
Validate
()
error
{
...
...
@@ -263,6 +401,86 @@ func (c *Config) Validate() error {
if
c
.
JWT
.
Secret
==
"change-me-in-production"
&&
c
.
Server
.
Mode
==
"release"
{
return
fmt
.
Errorf
(
"jwt.secret must be changed in production"
)
}
if
c
.
Database
.
MaxOpenConns
<=
0
{
return
fmt
.
Errorf
(
"database.max_open_conns must be positive"
)
}
if
c
.
Database
.
MaxIdleConns
<
0
{
return
fmt
.
Errorf
(
"database.max_idle_conns must be non-negative"
)
}
if
c
.
Database
.
MaxIdleConns
>
c
.
Database
.
MaxOpenConns
{
return
fmt
.
Errorf
(
"database.max_idle_conns cannot exceed database.max_open_conns"
)
}
if
c
.
Database
.
ConnMaxLifetimeMinutes
<
0
{
return
fmt
.
Errorf
(
"database.conn_max_lifetime_minutes must be non-negative"
)
}
if
c
.
Database
.
ConnMaxIdleTimeMinutes
<
0
{
return
fmt
.
Errorf
(
"database.conn_max_idle_time_minutes must be non-negative"
)
}
if
c
.
Redis
.
DialTimeoutSeconds
<=
0
{
return
fmt
.
Errorf
(
"redis.dial_timeout_seconds must be positive"
)
}
if
c
.
Redis
.
ReadTimeoutSeconds
<=
0
{
return
fmt
.
Errorf
(
"redis.read_timeout_seconds must be positive"
)
}
if
c
.
Redis
.
WriteTimeoutSeconds
<=
0
{
return
fmt
.
Errorf
(
"redis.write_timeout_seconds must be positive"
)
}
if
c
.
Redis
.
PoolSize
<=
0
{
return
fmt
.
Errorf
(
"redis.pool_size must be positive"
)
}
if
c
.
Redis
.
MinIdleConns
<
0
{
return
fmt
.
Errorf
(
"redis.min_idle_conns must be non-negative"
)
}
if
c
.
Redis
.
MinIdleConns
>
c
.
Redis
.
PoolSize
{
return
fmt
.
Errorf
(
"redis.min_idle_conns cannot exceed redis.pool_size"
)
}
if
c
.
Gateway
.
MaxBodySize
<=
0
{
return
fmt
.
Errorf
(
"gateway.max_body_size must be positive"
)
}
if
strings
.
TrimSpace
(
c
.
Gateway
.
ConnectionPoolIsolation
)
!=
""
{
switch
c
.
Gateway
.
ConnectionPoolIsolation
{
case
ConnectionPoolIsolationProxy
,
ConnectionPoolIsolationAccount
,
ConnectionPoolIsolationAccountProxy
:
default
:
return
fmt
.
Errorf
(
"gateway.connection_pool_isolation must be one of: %s/%s/%s"
,
ConnectionPoolIsolationProxy
,
ConnectionPoolIsolationAccount
,
ConnectionPoolIsolationAccountProxy
)
}
}
if
c
.
Gateway
.
MaxIdleConns
<=
0
{
return
fmt
.
Errorf
(
"gateway.max_idle_conns must be positive"
)
}
if
c
.
Gateway
.
MaxIdleConnsPerHost
<=
0
{
return
fmt
.
Errorf
(
"gateway.max_idle_conns_per_host must be positive"
)
}
if
c
.
Gateway
.
MaxConnsPerHost
<
0
{
return
fmt
.
Errorf
(
"gateway.max_conns_per_host must be non-negative"
)
}
if
c
.
Gateway
.
IdleConnTimeoutSeconds
<=
0
{
return
fmt
.
Errorf
(
"gateway.idle_conn_timeout_seconds must be positive"
)
}
if
c
.
Gateway
.
MaxUpstreamClients
<=
0
{
return
fmt
.
Errorf
(
"gateway.max_upstream_clients must be positive"
)
}
if
c
.
Gateway
.
ClientIdleTTLSeconds
<=
0
{
return
fmt
.
Errorf
(
"gateway.client_idle_ttl_seconds must be positive"
)
}
if
c
.
Gateway
.
ConcurrencySlotTTLMinutes
<=
0
{
return
fmt
.
Errorf
(
"gateway.concurrency_slot_ttl_minutes must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.sticky_session_max_waiting must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.sticky_session_wait_timeout must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.fallback_wait_timeout must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
FallbackMaxWaiting
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.fallback_max_waiting must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
SlotCleanupInterval
<
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.slot_cleanup_interval must be non-negative"
)
}
return
nil
}
...
...
backend/internal/config/config_test.go
View file @
7331220e
package
config
import
"testing"
import
(
"testing"
"time"
"github.com/spf13/viper"
)
func
TestNormalizeRunMode
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
...
...
@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) {
}
}
}
func
TestLoadDefaultSchedulingConfig
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
!=
3
{
t
.
Fatalf
(
"StickySessionMaxWaiting = %d, want 3"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
)
}
if
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
!=
45
*
time
.
Second
{
t
.
Fatalf
(
"StickySessionWaitTimeout = %v, want 45s"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
)
}
if
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
!=
30
*
time
.
Second
{
t
.
Fatalf
(
"FallbackWaitTimeout = %v, want 30s"
,
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
)
}
if
cfg
.
Gateway
.
Scheduling
.
FallbackMaxWaiting
!=
100
{
t
.
Fatalf
(
"FallbackMaxWaiting = %d, want 100"
,
cfg
.
Gateway
.
Scheduling
.
FallbackMaxWaiting
)
}
if
!
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
{
t
.
Fatalf
(
"LoadBatchEnabled = false, want true"
)
}
if
cfg
.
Gateway
.
Scheduling
.
SlotCleanupInterval
!=
30
*
time
.
Second
{
t
.
Fatalf
(
"SlotCleanupInterval = %v, want 30s"
,
cfg
.
Gateway
.
Scheduling
.
SlotCleanupInterval
)
}
}
func
TestLoadSchedulingConfigFromEnv
(
t
*
testing
.
T
)
{
viper
.
Reset
()
t
.
Setenv
(
"GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING"
,
"5"
)
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
!=
5
{
t
.
Fatalf
(
"StickySessionMaxWaiting = %d, want 5"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
)
}
}
backend/internal/handler/admin/setting_handler.go
View file @
7331220e
...
...
@@ -10,15 +10,17 @@ import (
// SettingHandler 系统设置处理器
type
SettingHandler
struct
{
settingService
*
service
.
SettingService
emailService
*
service
.
EmailService
settingService
*
service
.
SettingService
emailService
*
service
.
EmailService
turnstileService
*
service
.
TurnstileService
}
// NewSettingHandler 创建系统设置处理器
func
NewSettingHandler
(
settingService
*
service
.
SettingService
,
emailService
*
service
.
EmailService
)
*
SettingHandler
{
func
NewSettingHandler
(
settingService
*
service
.
SettingService
,
emailService
*
service
.
EmailService
,
turnstileService
*
service
.
TurnstileService
)
*
SettingHandler
{
return
&
SettingHandler
{
settingService
:
settingService
,
emailService
:
emailService
,
settingService
:
settingService
,
emailService
:
emailService
,
turnstileService
:
turnstileService
,
}
}
...
...
@@ -108,6 +110,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req
.
SmtpPort
=
587
}
// Turnstile 参数验证
if
req
.
TurnstileEnabled
{
// 检查必填字段
if
req
.
TurnstileSiteKey
==
""
{
response
.
BadRequest
(
c
,
"Turnstile Site Key is required when enabled"
)
return
}
if
req
.
TurnstileSecretKey
==
""
{
response
.
BadRequest
(
c
,
"Turnstile Secret Key is required when enabled"
)
return
}
// 获取当前设置,检查参数是否有变化
currentSettings
,
err
:=
h
.
settingService
.
GetAllSettings
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged
:=
currentSettings
.
TurnstileSiteKey
!=
req
.
TurnstileSiteKey
secretKeyChanged
:=
currentSettings
.
TurnstileSecretKey
!=
req
.
TurnstileSecretKey
if
siteKeyChanged
||
secretKeyChanged
{
if
err
:=
h
.
turnstileService
.
ValidateSecretKey
(
c
.
Request
.
Context
(),
req
.
TurnstileSecretKey
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
}
}
settings
:=
&
service
.
SystemSettings
{
RegistrationEnabled
:
req
.
RegistrationEnabled
,
EmailVerifyEnabled
:
req
.
EmailVerifyEnabled
,
...
...
backend/internal/handler/gateway_handler.go
View file @
7331220e
...
...
@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 读取请求体
body
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
h
.
errorResponse
(
c
,
http
.
StatusRequestEntityTooLarge
,
"invalid_request_error"
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
return
}
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to read request body"
)
return
}
...
...
@@ -76,15 +80,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 解析请求获取模型名和stream
var
req
struct
{
Model
string
`json:"model"`
Stream
bool
`json:"stream"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
}
reqModel
:=
parsedReq
.
Model
reqStream
:=
parsedReq
.
Stream
// 验证 model 必填
if
reqModel
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
return
}
// Track if we've started streaming (for error handling)
streamStarted
:=
false
...
...
@@ -106,7 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
// 1. 首先获取用户并发槽位
userReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireUserSlotWithWait
(
c
,
subject
.
UserID
,
subject
.
Concurrency
,
req
.
Stream
,
&
streamStarted
)
userReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireUserSlotWithWait
(
c
,
subject
.
UserID
,
subject
.
Concurrency
,
reqStream
,
&
streamStarted
)
if
err
!=
nil
{
log
.
Printf
(
"User concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"user"
,
streamStarted
)
...
...
@@ -124,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
platform
:=
""
...
...
@@ -133,6 +141,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
else
if
apiKey
.
Group
!=
nil
{
platform
=
apiKey
.
Group
.
Platform
}
sessionKey
:=
sessionHash
if
platform
==
service
.
PlatformGemini
&&
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
}
if
platform
==
service
.
PlatformGemini
{
const
maxAccountSwitches
=
3
...
...
@@ -141,7 +153,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
account
,
err
:=
h
.
g
eminiCompat
Service
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
req
.
Model
,
failedAccountIDs
)
selection
,
err
:=
h
.
g
ateway
Service
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -150,35 +162,77 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
account
:=
selection
.
Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
req
.
Stream
{
sendMockWarmupStream
(
c
,
req
.
Model
)
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
if
reqStream
{
sendMockWarmupStream
(
c
,
reqModel
)
}
else
{
sendMockWarmupResponse
(
c
,
req
.
Model
)
sendMockWarmupResponse
(
c
,
reqModel
)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
req
.
Stream
,
&
streamStarted
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
Request
.
Context
(),
c
,
account
,
req
.
Model
,
"generateContent"
,
req
.
Stream
,
body
)
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
Request
.
Context
(),
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
)
}
else
{
result
,
err
=
h
.
geminiCompatService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
@@ -223,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
// 选择支持该模型的账号
account
,
err
:=
h
.
gatewayService
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
req
.
Model
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -232,23 +286,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
account
:=
selection
.
Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
req
.
Stream
{
sendMockWarmupStream
(
c
,
req
.
Model
)
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
if
reqStream
{
sendMockWarmupStream
(
c
,
reqModel
)
}
else
{
sendMockWarmupResponse
(
c
,
req
.
Model
)
sendMockWarmupResponse
(
c
,
reqModel
)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
req
.
Stream
,
&
streamStarted
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 转发请求 - 根据账号平台分流
...
...
@@ -256,11 +349,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
}
else
{
result
,
err
=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
result
,
err
=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
parsedReq
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
@@ -525,6 +621,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 读取请求体
body
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
h
.
errorResponse
(
c
,
http
.
StatusRequestEntityTooLarge
,
"invalid_request_error"
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
return
}
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to read request body"
)
return
}
...
...
@@ -534,15 +634,18 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 解析请求获取模型名
var
req
struct
{
Model
string
`json:"model"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
}
// 验证 model 必填
if
parsedReq
.
Model
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
return
}
// 获取订阅信息(可能为nil)
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
...
...
@@ -554,17 +657,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
// 选择支持该模型的账号
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
r
eq
.
Model
)
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
parsedR
eq
.
Model
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
())
return
}
// 转发请求(不记录使用量)
if
err
:=
h
.
gatewayService
.
ForwardCountTokens
(
c
.
Request
.
Context
(),
c
,
account
,
body
);
err
!=
nil
{
if
err
:=
h
.
gatewayService
.
ForwardCountTokens
(
c
.
Request
.
Context
(),
c
,
account
,
parsedReq
);
err
!=
nil
{
log
.
Printf
(
"Forward count_tokens request failed: %v"
,
err
)
// 错误响应已在 ForwardCountTokens 中处理
return
...
...
backend/internal/handler/gateway_helper.go
View file @
7331220e
...
...
@@ -3,6 +3,7 @@ package handler
import
(
"context"
"fmt"
"math/rand"
"net/http"
"time"
...
...
@@ -11,11 +12,28 @@ import (
"github.com/gin-gonic/gin"
)
// 并发槽位等待相关常量
//
// 性能优化说明:
// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题:
// 1. 高并发时频繁轮询增加 Redis 压力
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
//
// 新实现使用指数退避 + 抖动算法:
// 1. 初始退避 100ms,每次乘以 1.5,最大 2s
// 2. 添加 ±20% 的随机抖动,分散重试时间点
// 3. 减少 Redis 压力,避免惊群效应
const
(
// maxConcurrencyWait
is the maximum time to wait for a concurrency slot
// maxConcurrencyWait
等待并发槽位的最大时间
maxConcurrencyWait
=
30
*
time
.
Second
// pingInterval
is the interval for sending ping events during slot wait
// pingInterval
流式响应等待时发送 ping 的间隔
pingInterval
=
15
*
time
.
Second
// initialBackoff 初始退避时间
initialBackoff
=
100
*
time
.
Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
backoffMultiplier
=
1.5
// maxBackoff 最大退避时间
maxBackoff
=
2
*
time
.
Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
...
...
@@ -65,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h
.
concurrencyService
.
DecrementWaitCount
(
ctx
,
userID
)
}
// IncrementAccountWaitCount increments the wait count for an account.
// It returns false when the account's wait queue has already reached
// maxWait, in which case the caller should reject the request rather
// than queue it. Delegates to the underlying concurrency service.
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account,
// releasing the queue slot reserved by a prior IncrementAccountWaitCount.
// Delegates to the underlying concurrency service.
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
	h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
...
...
@@ -108,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPing
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
maxConcurrencyWait
)
return
h
.
waitForSlotWithPingTimeout
(
c
,
slotType
,
id
,
maxConcurrency
,
maxConcurrencyWait
,
isStream
,
streamStarted
)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPingTimeout
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
defer
cancel
()
// Determine if ping is needed (streaming + ping format defined)
...
...
@@ -131,8 +164,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
pingCh
=
pingTicker
.
C
}
pollTicker
:=
time
.
NewTicker
(
100
*
time
.
Millisecond
)
defer
pollTicker
.
Stop
()
backoff
:=
initialBackoff
timer
:=
time
.
NewTimer
(
backoff
)
defer
timer
.
Stop
()
rng
:=
rand
.
New
(
rand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
for
{
select
{
...
...
@@ -156,7 +191,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
flusher
.
Flush
()
case
<-
pollTick
er
.
C
:
case
<-
tim
er
.
C
:
// Try to acquire slot
var
result
*
service
.
AcquireResult
var
err
error
...
...
@@ -174,6 +209,40 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
if
result
.
Acquired
{
return
result
.
ReleaseFunc
,
nil
}
backoff
=
nextBackoff
(
backoff
,
rng
)
timer
.
Reset
(
backoff
)
}
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
// It delegates to waitForSlotWithPingTimeout with slotType "account". The
// returned function releases the slot and must be called when the request
// finishes; it is nil when acquisition fails.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
	return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil,此时不添加抖动)
// 返回值:下一次退避时间(100ms ~ 2s 之间)
func
nextBackoff
(
current
time
.
Duration
,
rng
*
rand
.
Rand
)
time
.
Duration
{
// 指数退避:当前时间 * 1.5
next
:=
time
.
Duration
(
float64
(
current
)
*
backoffMultiplier
)
if
next
>
maxBackoff
{
next
=
maxBackoff
}
if
rng
==
nil
{
return
next
}
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter
:=
0.8
+
rng
.
Float64
()
*
0.4
jittered
:=
time
.
Duration
(
float64
(
next
)
*
jitter
)
if
jittered
<
initialBackoff
{
return
initialBackoff
}
if
jittered
>
maxBackoff
{
return
maxBackoff
}
return
jittered
}
backend/internal/handler/gemini_v1beta_handler.go
View file @
7331220e
...
...
@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
body
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
googleError
(
c
,
http
.
StatusRequestEntityTooLarge
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
return
}
googleError
(
c
,
http
.
StatusBadRequest
,
"Failed to read request body"
)
return
}
...
...
@@ -191,14 +195,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
body
)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionKey
:=
sessionHash
if
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
}
const
maxAccountSwitches
=
3
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
for
{
account
,
err
:=
h
.
g
eminiCompat
Service
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
modelName
,
failedAccountIDs
)
selection
,
err
:=
h
.
g
ateway
Service
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
modelName
,
failedAccountIDs
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
...
@@ -207,12 +216,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted
(
c
,
lastFailoverStatus
)
return
}
account
:=
selection
.
Account
// 4) account concurrency slot
accountReleaseFunc
,
err
:=
geminiConcurrency
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
stream
,
&
streamStarted
)
if
err
!=
nil
{
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts"
)
return
}
canWait
,
err
:=
geminiConcurrency
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
googleError
(
c
,
http
.
StatusTooManyRequests
,
"Too many pending requests, please retry later"
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
geminiConcurrency
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
geminiConcurrency
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
stream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 5) forward (根据平台分流)
...
...
@@ -225,6 +270,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
7331220e
...
...
@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Read request body
body
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
h
.
errorResponse
(
c
,
http
.
StatusRequestEntityTooLarge
,
"invalid_request_error"
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
return
}
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to read request body"
)
return
}
...
...
@@ -76,6 +80,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel
,
_
:=
reqBody
[
"model"
]
.
(
string
)
reqStream
,
_
:=
reqBody
[
"stream"
]
.
(
bool
)
// 验证 model 必填
if
reqModel
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
return
}
// For non-Codex CLI requests, set default instructions
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
if
!
openai
.
IsCodexCLIRequest
(
userAgent
)
{
...
...
@@ -136,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for
{
// Select account supporting the requested model
log
.
Printf
(
"[OpenAI Handler] Selecting account: groupID=%v model=%s"
,
apiKey
.
GroupID
,
reqModel
)
account
,
err
:=
h
.
gatewayService
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
log
.
Printf
(
"[OpenAI Handler] SelectAccount failed: %v"
,
err
)
if
len
(
failedAccountIDs
)
==
0
{
...
...
@@ -146,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
account
:=
selection
.
Account
log
.
Printf
(
"[OpenAI Handler] Selected account: id=%d name=%s"
,
account
.
ID
,
account
.
Name
)
// 3. Acquire account concurrency slot
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
reqStream
,
&
streamStarted
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionHash
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// Forward request
...
...
@@ -161,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
backend/internal/handler/request_body_limit.go
0 → 100644
View file @
7331220e
package
handler
import
(
"errors"
"fmt"
"net/http"
)
func
extractMaxBytesError
(
err
error
)
(
*
http
.
MaxBytesError
,
bool
)
{
var
maxErr
*
http
.
MaxBytesError
if
errors
.
As
(
err
,
&
maxErr
)
{
return
maxErr
,
true
}
return
nil
,
false
}
// formatBodyLimit renders a byte limit as a short human-readable size
// for error messages. It picks the largest fitting binary unit:
// >= 1 MiB prints "<n>MB", >= 1 KiB prints "<n>KB", and anything
// smaller falls back to a raw byte count "<n>B".
//
// Division truncates intentionally: a non-whole limit such as 1.5 MiB
// is reported as "1MB", which is acceptable for a message hint.
// (Fix over the original: sub-megabyte limits previously printed as
// raw byte counts like "524288B" instead of "512KB".)
func formatBodyLimit(limit int64) string {
	const (
		kb = int64(1024)
		mb = 1024 * kb
	)
	switch {
	case limit >= mb:
		return fmt.Sprintf("%dMB", limit/mb)
	case limit >= kb:
		return fmt.Sprintf("%dKB", limit/kb)
	default:
		return fmt.Sprintf("%dB", limit)
	}
}
func
buildBodyTooLargeMessage
(
limit
int64
)
string
{
return
fmt
.
Sprintf
(
"Request body too large, limit is %s"
,
formatBodyLimit
(
limit
))
}
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment