Unverified Commit 6bccb8a8 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretQuery is the builder for querying SecuritySecret entities.
// NOTE(review): this file is generated by ent ("DO NOT EDIT") — change the
// schema and regenerate rather than editing by hand.
type SecuritySecretQuery struct {
config
// ctx carries per-query options (Limit, Offset, Unique, selected Fields).
ctx *QueryContext
// order accumulates ORDER BY terms added via Order.
order []securitysecret.OrderOption
// inters are interceptors applied around query execution.
inters []Interceptor
// predicates accumulates WHERE conditions added via Where.
predicates []predicate.SecuritySecret
// modifiers are raw selector mutations (e.g. FOR UPDATE / FOR SHARE locks).
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the SecuritySecretQuery builder.
// Multiple predicates are ANDed together at execution time.
func (_q *SecuritySecretQuery) Where(ps ...predicate.SecuritySecret) *SecuritySecretQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *SecuritySecretQuery) Limit(limit int) *SecuritySecretQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *SecuritySecretQuery) Offset(offset int) *SecuritySecretQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *SecuritySecretQuery) Unique(unique bool) *SecuritySecretQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
// Terms append; later calls add secondary sort keys.
func (_q *SecuritySecretQuery) Order(o ...securitysecret.OrderOption) *SecuritySecretQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first SecuritySecret entity from the query.
// Returns a *NotFoundError when no SecuritySecret was found.
func (_q *SecuritySecretQuery) First(ctx context.Context) (*SecuritySecret, error) {
// Limit(1): only a single row is needed to decide found/not-found.
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{securitysecret.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
// A not-found result is NOT a panic: it returns nil instead.
func (_q *SecuritySecretQuery) FirstX(ctx context.Context) *SecuritySecret {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first SecuritySecret ID from the query.
// Returns a *NotFoundError when no SecuritySecret ID was found.
func (_q *SecuritySecretQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{securitysecret.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
// A not-found result is NOT a panic: it returns the zero ID instead.
func (_q *SecuritySecretQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single SecuritySecret entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one SecuritySecret entity is found.
// Returns a *NotFoundError when no SecuritySecret entities are found.
func (_q *SecuritySecretQuery) Only(ctx context.Context) (*SecuritySecret, error) {
// Limit(2): fetching at most two rows is enough to distinguish
// none / exactly-one / more-than-one without scanning the whole table.
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{securitysecret.Label}
default:
return nil, &NotSingularError{securitysecret.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *SecuritySecretQuery) OnlyX(ctx context.Context) *SecuritySecret {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only SecuritySecret ID in the query.
// Returns a *NotSingularError when more than one SecuritySecret ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *SecuritySecretQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{securitysecret.Label}
default:
err = &NotSingularError{securitysecret.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *SecuritySecretQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of SecuritySecrets.
func (_q *SecuritySecretQuery) All(ctx context.Context) ([]*SecuritySecret, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
// Run sqlAll through the interceptor chain so Traverse/Intercept hooks apply.
qr := querierAll[[]*SecuritySecret, *SecuritySecretQuery]()
return withInterceptors[[]*SecuritySecret](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *SecuritySecretQuery) AllX(ctx context.Context) []*SecuritySecret {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of SecuritySecret IDs.
func (_q *SecuritySecretQuery) IDs(ctx context.Context) (ids []int64, err error) {
// For traversal queries (path != nil) default Unique to true so the
// join produced by the traversal does not yield duplicate IDs.
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(securitysecret.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *SecuritySecretQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *SecuritySecretQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*SecuritySecretQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *SecuritySecretQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
// It is implemented on top of FirstID: not-found maps to (false, nil),
// any other error is wrapped, and success means at least one row exists.
func (_q *SecuritySecretQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *SecuritySecretQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the SecuritySecretQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *SecuritySecretQuery) Clone() *SecuritySecretQuery {
// Nil receiver clones to nil so chained clones of an absent query are safe.
if _q == nil {
return nil
}
// Slices are copied into fresh backing arrays so appends on the clone
// do not mutate the original builder.
return &SecuritySecretQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]securitysecret.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.SecuritySecret{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.SecuritySecret.Query().
// GroupBy(securitysecret.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *SecuritySecretQuery) GroupBy(field string, fields ...string) *SecuritySecretGroupBy {
// GroupBy replaces (rather than appends to) the selected field set.
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &SecuritySecretGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = securitysecret.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.SecuritySecret.Query().
// Select(securitysecret.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *SecuritySecretQuery) Select(fields ...string) *SecuritySecretSelect {
// Select appends to any previously selected fields.
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &SecuritySecretSelect{SecuritySecretQuery: _q}
sbuild.label = securitysecret.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a SecuritySecretSelect configured with the given aggregations.
func (_q *SecuritySecretQuery) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect {
return _q.Select().Aggregate(fns...)
}
// prepareQuery runs traverser interceptors, validates selected columns,
// and resolves the deferred traversal path (if any) into _q.sql.
// It is called once before every execution path (All, Count, Scan, ...).
func (_q *SecuritySecretQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
// Only Traverser interceptors run here; Querier ones wrap execution later.
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !securitysecret.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
// Materialize the lazy traversal path into a concrete selector.
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
// sqlAll executes the query and scans all matching rows into SecuritySecret
// structs. hooks may mutate the query spec just before execution.
func (_q *SecuritySecretQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SecuritySecret, error) {
var (
nodes = []*SecuritySecret{}
_spec = _q.querySpec()
)
// ScanValues allocates typed scan destinations for the requested columns.
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*SecuritySecret).scanValues(nil, columns)
}
// Assign copies one scanned row into a new node and collects it.
_spec.Assign = func(columns []string, values []any) error {
node := &SecuritySecret{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
// NOTE(review): both branches below return the same value. This is the
// generator's eager-loading splice point; it degenerates to a no-op
// because SecuritySecret has no edges. Harmless — leave for regeneration.
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
// sqlCount executes a COUNT over the current query spec.
func (_q *SecuritySecretQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
// Unique (COUNT DISTINCT) is only meaningful when explicit fields were selected.
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
// querySpec translates the builder state (predicates, order, limit/offset,
// selected columns, uniqueness) into a sqlgraph.QuerySpec for execution.
func (_q *SecuritySecretQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
// Traversal queries default to unique to suppress join duplicates.
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
// The ID column is always selected first; skip it if the caller
// listed it explicitly to avoid a duplicate column.
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID)
for i := range fields {
if fields[i] != securitysecret.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
// sqlQuery builds the raw SELECT used by GroupBy/Select scans, applying
// selected columns, uniqueness, modifiers, predicates, order, and paging.
func (_q *SecuritySecretQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(securitysecret.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = securitysecret.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
// A pre-built selector (from a traversal) takes precedence; re-apply
// the column selection on top of it.
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *SecuritySecretQuery) ForUpdate(opts ...sql.LockOption) *SecuritySecretQuery {
// PostgreSQL rejects SELECT DISTINCT ... FOR UPDATE, so uniqueness
// (which renders as DISTINCT) must be disabled for this dialect.
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *SecuritySecretQuery) ForShare(opts ...sql.LockOption) *SecuritySecretQuery {
// Same PostgreSQL DISTINCT restriction as ForUpdate.
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// SecuritySecretGroupBy is the group-by builder for SecuritySecret entities.
type SecuritySecretGroupBy struct {
selector
// build is the underlying query this group-by executes against.
build *SecuritySecretQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *SecuritySecretGroupBy) Aggregate(fns ...AggregateFunc) *SecuritySecretGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
// v is typically a pointer to a slice of structs matching the grouped columns.
func (_g *SecuritySecretGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
// sqlScan builds and executes the GROUP BY statement, scanning rows into v.
func (_g *SecuritySecretGroupBy) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error {
// Start from the root query's selector with the column list cleared.
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected by an aggregate fn, select the grouped
// fields followed by the aggregation expressions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// SecuritySecretSelect is the builder for selecting fields of SecuritySecret entities.
// It embeds the query so all builder methods remain available.
type SecuritySecretSelect struct {
*SecuritySecretQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *SecuritySecretSelect) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
// v is typically a pointer to a slice of structs matching the selected columns.
func (_s *SecuritySecretSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretSelect](ctx, _s.SecuritySecretQuery, _s, _s.inters, v)
}
// sqlScan builds and executes the SELECT statement, scanning rows into v.
func (_s *SecuritySecretSelect) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
// Aggregations either replace the column list (no fields selected)
// or are appended after the explicitly selected fields.
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretUpdate is the builder for updating SecuritySecret entities.
// NOTE(review): generated by ent ("DO NOT EDIT") — regenerate rather than edit.
type SecuritySecretUpdate struct {
config
// hooks run around sqlSave via withHooks.
hooks []Hook
// mutation records the field changes and predicates to apply.
mutation *SecuritySecretMutation
}
// Where appends a list predicates to the SecuritySecretUpdate builder.
func (_u *SecuritySecretUpdate) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SecuritySecretUpdate) SetUpdatedAt(v time.Time) *SecuritySecretUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetKey sets the "key" field.
func (_u *SecuritySecretUpdate) SetKey(v string) *SecuritySecretUpdate {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *SecuritySecretUpdate) SetNillableKey(v *string) *SecuritySecretUpdate {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *SecuritySecretUpdate) SetValue(v string) *SecuritySecretUpdate {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *SecuritySecretUpdate) SetNillableValue(v *string) *SecuritySecretUpdate {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_u *SecuritySecretUpdate) Mutation() *SecuritySecretMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *SecuritySecretUpdate) Save(ctx context.Context) (int, error) {
// Apply schema defaults (e.g. updated_at) before running hooks + sqlSave.
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SecuritySecretUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *SecuritySecretUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SecuritySecretUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
// updated_at is stamped only if the caller did not set it explicitly.
func (_u *SecuritySecretUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := securitysecret.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
// Only fields actually present in the mutation are validated.
func (_u *SecuritySecretUpdate) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
// sqlSave validates the mutation, builds the UPDATE spec from the recorded
// field changes and predicates, and returns the number of affected rows.
func (_u *SecuritySecretUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
}
// Translate low-level sqlgraph errors into the package's typed errors.
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{securitysecret.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// SecuritySecretUpdateOne is the builder for updating a single SecuritySecret entity.
type SecuritySecretUpdateOne struct {
config
// fields optionally restricts which columns of the updated entity are returned.
fields []string
// hooks run around sqlSave via withHooks.
hooks []Hook
// mutation records the field changes, ID, and predicates to apply.
mutation *SecuritySecretMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SecuritySecretUpdateOne) SetUpdatedAt(v time.Time) *SecuritySecretUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetKey sets the "key" field.
func (_u *SecuritySecretUpdateOne) SetKey(v string) *SecuritySecretUpdateOne {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *SecuritySecretUpdateOne) SetNillableKey(v *string) *SecuritySecretUpdateOne {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *SecuritySecretUpdateOne) SetValue(v string) *SecuritySecretUpdateOne {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *SecuritySecretUpdateOne) SetNillableValue(v *string) *SecuritySecretUpdateOne {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_u *SecuritySecretUpdateOne) Mutation() *SecuritySecretMutation {
return _u.mutation
}
// Where appends a list predicates to the SecuritySecretUpdate builder.
func (_u *SecuritySecretUpdateOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *SecuritySecretUpdateOne) Select(field string, fields ...string) *SecuritySecretUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated SecuritySecret entity.
func (_u *SecuritySecretUpdateOne) Save(ctx context.Context) (*SecuritySecret, error) {
// Apply schema defaults (e.g. updated_at) before running hooks + sqlSave.
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SecuritySecretUpdateOne) SaveX(ctx context.Context) *SecuritySecret {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *SecuritySecretUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SecuritySecretUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
// updated_at is stamped only if the caller did not set it explicitly.
func (_u *SecuritySecretUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := securitysecret.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
// Only fields actually present in the mutation are validated.
func (_u *SecuritySecretUpdateOne) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
// sqlSave validates the mutation, builds the single-row UPDATE spec keyed by
// the mutation's ID, executes it, and returns the refreshed entity.
func (_u *SecuritySecretUpdateOne) sqlSave(ctx context.Context) (_node *SecuritySecret, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
// UpdateOne requires the target ID; it is set by Client.UpdateOneID/UpdateOne.
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SecuritySecret.id" for update`)}
}
_spec.Node.ID.Value = id
// Restrict returned columns when Select was used; ID is always included.
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID)
for _, f := range fields {
if !securitysecret.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != securitysecret.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
}
// The updated row is scanned back into _node via Assign/ScanValues.
_node = &SecuritySecret{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
// Translate low-level sqlgraph errors into the package's typed errors.
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{securitysecret.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}
...@@ -36,6 +36,8 @@ type Tx struct { ...@@ -36,6 +36,8 @@ type Tx struct {
Proxy *ProxyClient Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders. // RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient RedeemCode *RedeemCodeClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders. // Setting is the client for interacting with the Setting builders.
Setting *SettingClient Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
...@@ -194,6 +196,7 @@ func (tx *Tx) init() { ...@@ -194,6 +196,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config) tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config) tx.Setting = NewSettingClient(tx.config)
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config)
......
...@@ -80,6 +80,8 @@ type UsageLog struct { ...@@ -80,6 +80,8 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"` ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field. // ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"` ImageSize *string `json:"image_size,omitempty"`
// MediaType holds the value of the "media_type" field.
MediaType *string `json:"media_type,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field. // CreatedAt holds the value of the "created_at" field.
...@@ -173,7 +175,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { ...@@ -173,7 +175,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
...@@ -380,6 +382,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { ...@@ -380,6 +382,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string) _m.ImageSize = new(string)
*_m.ImageSize = value.String *_m.ImageSize = value.String
} }
case usagelog.FieldMediaType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field media_type", values[i])
} else if value.Valid {
_m.MediaType = new(string)
*_m.MediaType = value.String
}
case usagelog.FieldCacheTTLOverridden: case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok { if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
...@@ -556,6 +565,11 @@ func (_m *UsageLog) String() string { ...@@ -556,6 +565,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.MediaType; v != nil {
builder.WriteString("media_type=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=") builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ") builder.WriteString(", ")
......
...@@ -72,6 +72,8 @@ const ( ...@@ -72,6 +72,8 @@ const (
FieldImageCount = "image_count" FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database. // FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size" FieldImageSize = "image_size"
// FieldMediaType holds the string denoting the media_type field in the database.
FieldMediaType = "media_type"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden" FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database. // FieldCreatedAt holds the string denoting the created_at field in the database.
...@@ -157,6 +159,7 @@ var Columns = []string{ ...@@ -157,6 +159,7 @@ var Columns = []string{
FieldIPAddress, FieldIPAddress,
FieldImageCount, FieldImageCount,
FieldImageSize, FieldImageSize,
FieldMediaType,
FieldCacheTTLOverridden, FieldCacheTTLOverridden,
FieldCreatedAt, FieldCreatedAt,
} }
...@@ -214,6 +217,8 @@ var ( ...@@ -214,6 +217,8 @@ var (
DefaultImageCount int DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error ImageSizeValidator func(string) error
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
MediaTypeValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field. // DefaultCreatedAt holds the default value on creation for the "created_at" field.
...@@ -373,6 +378,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { ...@@ -373,6 +378,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc() return sql.OrderByField(FieldImageSize, opts...).ToFunc()
} }
// ByMediaType orders the results by the media_type field.
func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
}
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. // ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
......
...@@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog { ...@@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
} }
// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
func MediaType(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. // CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog { func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
...@@ -1445,6 +1450,81 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { ...@@ -1445,6 +1450,81 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
} }
// MediaTypeEQ applies the EQ predicate on the "media_type" field.
func MediaTypeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
func MediaTypeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
}
// MediaTypeIn applies the In predicate on the "media_type" field.
func MediaTypeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
}
// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
func MediaTypeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
}
// MediaTypeGT applies the GT predicate on the "media_type" field.
func MediaTypeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
}
// MediaTypeGTE applies the GTE predicate on the "media_type" field.
func MediaTypeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
}
// MediaTypeLT applies the LT predicate on the "media_type" field.
func MediaTypeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
}
// MediaTypeLTE applies the LTE predicate on the "media_type" field.
func MediaTypeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
}
// MediaTypeContains applies the Contains predicate on the "media_type" field.
func MediaTypeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
}
// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
func MediaTypeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
}
// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
func MediaTypeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
}
// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
func MediaTypeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
}
// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
func MediaTypeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
}
// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
func MediaTypeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
}
// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
func MediaTypeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
}
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. // CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
......
...@@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { ...@@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c return _c
} }
// SetMediaType sets the "media_type" field.
func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
_c.mutation.SetMediaType(v)
return _c
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
if v != nil {
_c.SetMediaType(*v)
}
return _c
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. // SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v) _c.mutation.SetCacheTTLOverridden(v)
...@@ -645,6 +659,11 @@ func (_c *UsageLogCreate) check() error { ...@@ -645,6 +659,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
} }
} }
if v, ok := _c.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok { if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
} }
...@@ -783,6 +802,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { ...@@ -783,6 +802,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value _node.ImageSize = &value
} }
if value, ok := _c.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
_node.MediaType = &value
}
if value, ok := _c.mutation.CacheTTLOverridden(); ok { if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value _node.CacheTTLOverridden = value
...@@ -1432,6 +1455,24 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { ...@@ -1432,6 +1455,24 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u return u
} }
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
u.Set(usagelog.FieldMediaType, v)
return u
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldMediaType)
return u
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
u.SetNull(usagelog.FieldMediaType)
return u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. // SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v) u.Set(usagelog.FieldCacheTTLOverridden, v)
...@@ -2077,6 +2118,27 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { ...@@ -2077,6 +2118,27 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
}) })
} }
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. // SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
...@@ -2890,6 +2952,27 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { ...@@ -2890,6 +2952,27 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
}) })
} }
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. // SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
......
...@@ -612,6 +612,26 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { ...@@ -612,6 +612,26 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u return _u
} }
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
_u.mutation.ClearMediaType()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. // SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v) _u.mutation.SetCacheTTLOverridden(v)
...@@ -740,6 +760,11 @@ func (_u *UsageLogUpdate) check() error { ...@@ -740,6 +760,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
} }
} }
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
} }
...@@ -908,6 +933,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -908,6 +933,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() { if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString) _spec.ClearField(usagelog.FieldImageSize, field.TypeString)
} }
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok { if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
} }
...@@ -1656,6 +1687,26 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { ...@@ -1656,6 +1687,26 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u return _u
} }
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
_u.mutation.ClearMediaType()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. // SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v) _u.mutation.SetCacheTTLOverridden(v)
...@@ -1797,6 +1848,11 @@ func (_u *UsageLogUpdateOne) check() error { ...@@ -1797,6 +1848,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
} }
} }
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
} }
...@@ -1982,6 +2038,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err ...@@ -1982,6 +2038,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() { if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString) _spec.ClearField(usagelog.FieldImageSize, field.TypeString)
} }
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok { if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
} }
......
...@@ -5,6 +5,8 @@ go 1.25.7 ...@@ -5,6 +5,8 @@ go 1.25.7
require ( require (
entgo.io/ent v0.14.5 entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alitto/pond/v2 v2.6.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/dgraph-io/ristretto v0.2.0 github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang-jwt/jwt/v5 v5.2.2
...@@ -13,6 +15,7 @@ require ( ...@@ -13,6 +15,7 @@ require (
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0 github.com/imroc/req/v3 v3.57.0
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pquerna/otp v1.5.0 github.com/pquerna/otp v1.5.0
github.com/redis/go-redis/v9 v9.17.2 github.com/redis/go-redis/v9 v9.17.2
github.com/refraction-networking/utls v1.8.1 github.com/refraction-networking/utls v1.8.1
...@@ -25,10 +28,12 @@ require ( ...@@ -25,10 +28,12 @@ require (
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
github.com/zeromicro/go-zero v1.9.4 github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.47.0 golang.org/x/crypto v0.47.0
golang.org/x/net v0.49.0 golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0 golang.org/x/sync v0.19.0
golang.org/x/term v0.39.0 golang.org/x/term v0.39.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3 modernc.org/sqlite v1.44.3
) )
...@@ -45,7 +50,6 @@ require ( ...@@ -45,7 +50,6 @@ require (
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect
...@@ -75,6 +79,7 @@ require ( ...@@ -75,6 +79,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect
...@@ -103,7 +108,6 @@ require ( ...@@ -103,7 +108,6 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
...@@ -144,6 +148,7 @@ require ( ...@@ -144,6 +148,7 @@ require (
golang.org/x/mod v0.31.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/sys v0.40.0 // indirect golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.40.0 // indirect
google.golang.org/grpc v1.75.1 // indirect google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
......
...@@ -14,10 +14,14 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo ...@@ -14,10 +14,14 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
...@@ -116,6 +120,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 ...@@ -116,6 +120,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
...@@ -135,8 +141,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= ...@@ -135,8 +141,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
...@@ -340,10 +344,14 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= ...@@ -340,10 +344,14 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
...@@ -391,6 +399,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN ...@@ -391,6 +399,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
......
...@@ -5,7 +5,7 @@ import ( ...@@ -5,7 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"log" "log/slog"
"net/url" "net/url"
"os" "os"
"strings" "strings"
...@@ -19,6 +19,13 @@ const ( ...@@ -19,6 +19,13 @@ const (
RunModeSimple = "simple" RunModeSimple = "simple"
) )
// 使用量记录队列溢出策略
const (
UsageRecordOverflowPolicyDrop = "drop"
UsageRecordOverflowPolicySample = "sample"
UsageRecordOverflowPolicySync = "sync"
)
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support // DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
...@@ -38,31 +45,68 @@ const ( ...@@ -38,31 +45,68 @@ const (
) )
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"` Log LogConfig `mapstructure:"log"`
Security SecurityConfig `mapstructure:"security"` CORS CORSConfig `mapstructure:"cors"`
Billing BillingConfig `mapstructure:"billing"` Security SecurityConfig `mapstructure:"security"`
Turnstile TurnstileConfig `mapstructure:"turnstile"` Billing BillingConfig `mapstructure:"billing"`
Database DatabaseConfig `mapstructure:"database"` Turnstile TurnstileConfig `mapstructure:"turnstile"`
Redis RedisConfig `mapstructure:"redis"` Database DatabaseConfig `mapstructure:"database"`
Ops OpsConfig `mapstructure:"ops"` Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"` Ops OpsConfig `mapstructure:"ops"`
Totp TotpConfig `mapstructure:"totp"` JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` Totp TotpConfig `mapstructure:"totp"`
Default DefaultConfig `mapstructure:"default"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"` Default DefaultConfig `mapstructure:"default"`
Pricing PricingConfig `mapstructure:"pricing"` RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Gateway GatewayConfig `mapstructure:"gateway"` Pricing PricingConfig `mapstructure:"pricing"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` Gateway GatewayConfig `mapstructure:"gateway"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"` Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"` UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
Gemini GeminiConfig `mapstructure:"gemini"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Update UpdateConfig `mapstructure:"update"` Sora SoraConfig `mapstructure:"sora"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
}
type LogConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
ServiceName string `mapstructure:"service_name"`
Environment string `mapstructure:"env"`
Caller bool `mapstructure:"caller"`
StacktraceLevel string `mapstructure:"stacktrace_level"`
Output LogOutputConfig `mapstructure:"output"`
Rotation LogRotationConfig `mapstructure:"rotation"`
Sampling LogSamplingConfig `mapstructure:"sampling"`
}
type LogOutputConfig struct {
ToStdout bool `mapstructure:"to_stdout"`
ToFile bool `mapstructure:"to_file"`
FilePath string `mapstructure:"file_path"`
}
type LogRotationConfig struct {
MaxSizeMB int `mapstructure:"max_size_mb"`
MaxBackups int `mapstructure:"max_backups"`
MaxAgeDays int `mapstructure:"max_age_days"`
Compress bool `mapstructure:"compress"`
LocalTime bool `mapstructure:"local_time"`
}
type LogSamplingConfig struct {
Enabled bool `mapstructure:"enabled"`
Initial int `mapstructure:"initial"`
Thereafter int `mapstructure:"thereafter"`
} }
type GeminiConfig struct { type GeminiConfig struct {
...@@ -94,6 +138,25 @@ type UpdateConfig struct { ...@@ -94,6 +138,25 @@ type UpdateConfig struct {
ProxyURL string `mapstructure:"proxy_url"` ProxyURL string `mapstructure:"proxy_url"`
} }
type IdempotencyConfig struct {
// ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。
ObserveOnly bool `mapstructure:"observe_only"`
// DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。
DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"`
// SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。
SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"`
// ProcessingTimeoutSeconds processing 状态锁超时(秒)。
ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"`
// FailedRetryBackoffSeconds 失败退避窗口(秒)。
FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"`
// MaxStoredResponseLen 持久化响应体最大长度(字节)。
MaxStoredResponseLen int `mapstructure:"max_stored_response_len"`
// CleanupIntervalSeconds 过期记录清理周期(秒)。
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
// CleanupBatchSize 每次清理的最大记录数。
CleanupBatchSize int `mapstructure:"cleanup_batch_size"`
}
type LinuxDoConnectConfig struct { type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"` Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"` ClientID string `mapstructure:"client_id"`
...@@ -126,6 +189,8 @@ type TokenRefreshConfig struct { ...@@ -126,6 +189,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒) // 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
} }
type PricingConfig struct { type PricingConfig struct {
...@@ -147,6 +212,7 @@ type ServerConfig struct { ...@@ -147,6 +212,7 @@ type ServerConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release Mode string `mapstructure:"mode"` // debug/release
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
...@@ -173,6 +239,7 @@ type SecurityConfig struct { ...@@ -173,6 +239,7 @@ type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"` CSP CSPConfig `mapstructure:"csp"`
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
} }
...@@ -197,6 +264,12 @@ type CSPConfig struct { ...@@ -197,6 +264,12 @@ type CSPConfig struct {
Policy string `mapstructure:"policy"` Policy string `mapstructure:"policy"`
} }
type ProxyFallbackConfig struct {
// AllowDirectOnError 当代理初始化失败时是否允许回退直连。
// 默认 false:避免因代理配置错误导致 IP 泄露/关联。
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
}
type ProxyProbeConfig struct { type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证 InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
} }
...@@ -217,6 +290,59 @@ type ConcurrencyConfig struct { ...@@ -217,6 +290,59 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"` PingInterval int `mapstructure:"ping_interval"`
} }
// SoraConfig 直连 Sora 配置
type SoraConfig struct {
Client SoraClientConfig `mapstructure:"client"`
Storage SoraStorageConfig `mapstructure:"storage"`
}
// SoraClientConfig 直连 Sora 客户端配置
type SoraClientConfig struct {
BaseURL string `mapstructure:"base_url"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
MaxRetries int `mapstructure:"max_retries"`
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
Debug bool `mapstructure:"debug"`
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
}
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
type SoraCurlCFFISidecarConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"`
Impersonate string `mapstructure:"impersonate"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
}
// SoraStorageConfig 媒体存储配置
type SoraStorageConfig struct {
Type string `mapstructure:"type"`
LocalPath string `mapstructure:"local_path"`
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
Debug bool `mapstructure:"debug"`
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
}
// SoraStorageCleanupConfig 媒体清理配置
type SoraStorageCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
RetentionDays int `mapstructure:"retention_days"`
}
// GatewayConfig API网关相关配置 // GatewayConfig API网关相关配置
type GatewayConfig struct { type GatewayConfig struct {
// 等待上游响应头的超时时间(秒),0表示无超时 // 等待上游响应头的超时时间(秒),0表示无超时
...@@ -224,8 +350,20 @@ type GatewayConfig struct { ...@@ -224,8 +350,20 @@ type GatewayConfig struct {
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制 // 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"` MaxBodySize int64 `mapstructure:"max_body_size"`
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
// 代理探测响应体读取上限(字节)
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优) // HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数 // MaxIdleConns: 所有主机的最大空闲连接总数
...@@ -271,6 +409,24 @@ type GatewayConfig struct { ...@@ -271,6 +409,24 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"` FailoverOn400 bool `mapstructure:"failover_on_400"`
// Sora 专用配置
// SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
// SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
// SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
// SoraStreamMode: stream 强制策略(force/error)
SoraStreamMode string `mapstructure:"sora_stream_mode"`
// SoraModelFilters: 模型列表过滤配置
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"` MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
...@@ -284,6 +440,53 @@ type GatewayConfig struct { ...@@ -284,6 +440,53 @@ type GatewayConfig struct {
// TLSFingerprint: TLS指纹伪装配置 // TLSFingerprint: TLS指纹伪装配置
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
// UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒)
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒)
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
WorkerCount int `mapstructure:"worker_count"`
// QueueSize: 队列容量(有界)
QueueSize int `mapstructure:"queue_size"`
// TaskTimeoutSeconds: 单个使用量记录任务超时(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
// OverflowPolicy: 队列满时策略(drop/sample/sync)
OverflowPolicy string `mapstructure:"overflow_policy"`
// OverflowSamplePercent: sample 策略下,同步回写采样百分比(1-100)
OverflowSamplePercent int `mapstructure:"overflow_sample_percent"`
// AutoScaleEnabled: 是否启用 worker 自动扩缩容
AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
// AutoScaleMinWorkers: 自动扩缩容最小 worker 数
AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
// AutoScaleMaxWorkers: 自动扩缩容最大 worker 数
AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
// AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容
AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
// AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容
AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
// AutoScaleUpStep: 每次扩容步长
AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
// AutoScaleDownStep: 每次缩容步长
AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
// AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒)
AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
// AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒)
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
}
// SoraModelFiltersConfig Sora 模型过滤配置
type SoraModelFiltersConfig struct {
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
} }
// TLSFingerprintConfig TLS指纹伪装配置 // TLSFingerprintConfig TLS指纹伪装配置
...@@ -479,8 +682,9 @@ type OpsMetricsCollectorCacheConfig struct { ...@@ -479,8 +682,9 @@ type OpsMetricsCollectorCacheConfig struct {
type JWTConfig struct { type JWTConfig struct {
Secret string `mapstructure:"secret"` Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"` ExpireHour int `mapstructure:"expire_hour"`
// AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟 // AccessTokenExpireMinutes: Access Token有效期(分钟)
// 短有效期减少被盗用风险,配合Refresh Token实现无感续期 // - >0: 使用分钟配置(优先级高于 ExpireHour)
// - =0: 回退使用 ExpireHour(向后兼容旧配置)
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
// RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
...@@ -525,6 +729,20 @@ type APIKeyAuthCacheConfig struct { ...@@ -525,6 +729,20 @@ type APIKeyAuthCacheConfig struct {
Singleflight bool `mapstructure:"singleflight"` Singleflight bool `mapstructure:"singleflight"`
} }
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type SubscriptionCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
}
// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。
// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。
type SubscriptionMaintenanceConfig struct {
WorkerCount int `mapstructure:"worker_count"`
QueueSize int `mapstructure:"queue_size"`
}
// DashboardCacheConfig 仪表盘统计缓存配置 // DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct { type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存 // Enabled: 是否启用仪表盘缓存
...@@ -588,7 +806,19 @@ func NormalizeRunMode(value string) string { ...@@ -588,7 +806,19 @@ func NormalizeRunMode(value string) string {
} }
} }
// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。
func Load() (*Config, error) { func Load() (*Config, error) {
return load(false)
}
// LoadForBootstrap 读取启动阶段配置。
//
// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。
func LoadForBootstrap() (*Config, error) {
return load(true)
}
func load(allowMissingJWTSecret bool) (*Config, error) {
viper.SetConfigName("config") viper.SetConfigName("config")
viper.SetConfigType("yaml") viper.SetConfigType("yaml")
...@@ -630,6 +860,7 @@ func Load() (*Config, error) { ...@@ -630,6 +860,7 @@ func Load() (*Config, error) {
if cfg.Server.Mode == "" { if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug" cfg.Server.Mode = "debug"
} }
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
...@@ -648,15 +879,12 @@ func Load() (*Config, error) { ...@@ -648,15 +879,12 @@ func Load() (*Config, error) {
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy) cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
if cfg.JWT.Secret == "" { cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
secret, err := generateJWTSecret(64) cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
if err != nil { cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
return nil, fmt.Errorf("generate jwt secret error: %w", err) cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
} cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
...@@ -667,29 +895,39 @@ func Load() (*Config, error) { ...@@ -667,29 +895,39 @@ func Load() (*Config, error) {
} }
cfg.Totp.EncryptionKey = key cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.") slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else { } else {
cfg.Totp.EncryptionKeyConfigured = true cfg.Totp.EncryptionKeyConfigured = true
} }
originalJWTSecret := cfg.JWT.Secret
if allowMissingJWTSecret && originalJWTSecret == "" {
// 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。
cfg.JWT.Secret = strings.Repeat("0", 32)
}
if err := cfg.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err) return nil, fmt.Errorf("validate config error: %w", err)
} }
if allowMissingJWTSecret && originalJWTSecret == "" {
cfg.JWT.Secret = ""
}
if !cfg.Security.URLAllowlist.Enabled { if !cfg.Security.URLAllowlist.Enabled {
log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
} }
if !cfg.Security.ResponseHeaders.Enabled { if !cfg.Security.ResponseHeaders.Enabled {
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
} }
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) { if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.") slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
} }
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 { if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v", slog.Info("response header policy configured",
cfg.Security.ResponseHeaders.AdditionalAllowed, "additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed,
cfg.Security.ResponseHeaders.ForceRemove, "force_remove", cfg.Security.ResponseHeaders.ForceRemove,
) )
} }
...@@ -702,7 +940,8 @@ func setDefaults() { ...@@ -702,7 +940,8 @@ func setDefaults() {
// Server // Server
viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080) viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug") viper.SetDefault("server.mode", "release")
viper.SetDefault("server.frontend_url", "")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{}) viper.SetDefault("server.trusted_proxies", []string{})
...@@ -715,6 +954,25 @@ func setDefaults() { ...@@ -715,6 +954,25 @@ func setDefaults() {
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
// Log
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "console")
viper.SetDefault("log.service_name", "sub2api")
viper.SetDefault("log.env", "production")
viper.SetDefault("log.caller", true)
viper.SetDefault("log.stacktrace_level", "error")
viper.SetDefault("log.output.to_stdout", true)
viper.SetDefault("log.output.to_file", true)
viper.SetDefault("log.output.file_path", "")
viper.SetDefault("log.rotation.max_size_mb", 100)
viper.SetDefault("log.rotation.max_backups", 10)
viper.SetDefault("log.rotation.max_age_days", 7)
viper.SetDefault("log.rotation.compress", true)
viper.SetDefault("log.rotation.local_time", true)
viper.SetDefault("log.sampling.enabled", false)
viper.SetDefault("log.sampling.initial", 100)
viper.SetDefault("log.sampling.thereafter", 100)
// CORS // CORS
viper.SetDefault("cors.allowed_origins", []string{}) viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true) viper.SetDefault("cors.allow_credentials", true)
...@@ -737,7 +995,7 @@ func setDefaults() { ...@@ -737,7 +995,7 @@ func setDefaults() {
viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
viper.SetDefault("security.url_allowlist.allow_insecure_http", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
viper.SetDefault("security.response_headers.enabled", false) viper.SetDefault("security.response_headers.enabled", true)
viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true) viper.SetDefault("security.csp.enabled", true)
...@@ -775,9 +1033,9 @@ func setDefaults() { ...@@ -775,9 +1033,9 @@ func setDefaults() {
viper.SetDefault("database.user", "postgres") viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres") viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api") viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable") viper.SetDefault("database.sslmode", "prefer")
viper.SetDefault("database.max_open_conns", 50) viper.SetDefault("database.max_open_conns", 256)
viper.SetDefault("database.max_idle_conns", 10) viper.SetDefault("database.max_idle_conns", 128)
viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5) viper.SetDefault("database.conn_max_idle_time_minutes", 5)
...@@ -789,8 +1047,8 @@ func setDefaults() { ...@@ -789,8 +1047,8 @@ func setDefaults() {
viper.SetDefault("redis.dial_timeout_seconds", 5) viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3) viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3) viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128) viper.SetDefault("redis.pool_size", 1024)
viper.SetDefault("redis.min_idle_conns", 10) viper.SetDefault("redis.min_idle_conns", 128)
viper.SetDefault("redis.enable_tls", false) viper.SetDefault("redis.enable_tls", false)
// Ops (vNext) // Ops (vNext)
...@@ -810,9 +1068,9 @@ func setDefaults() { ...@@ -810,9 +1068,9 @@ func setDefaults() {
// JWT // JWT
viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24) viper.SetDefault("jwt.expire_hour", 24)
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期 viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
// TOTP // TOTP
viper.SetDefault("totp.encryption_key", "") viper.SetDefault("totp.encryption_key", "")
...@@ -849,6 +1107,11 @@ func setDefaults() { ...@@ -849,6 +1107,11 @@ func setDefaults() {
viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true) viper.SetDefault("api_key_auth_cache.singleflight", true)
// Subscription auth L1 cache
viper.SetDefault("subscription_cache.l1_size", 16384)
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
viper.SetDefault("subscription_cache.jitter_percent", 10)
// Dashboard cache // Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
...@@ -874,6 +1137,16 @@ func setDefaults() { ...@@ -874,6 +1137,16 @@ func setDefaults() {
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Idempotency
viper.SetDefault("idempotency.observe_only", true)
viper.SetDefault("idempotency.default_ttl_seconds", 86400)
viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600)
viper.SetDefault("idempotency.processing_timeout_seconds", 30)
viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5)
viper.SetDefault("idempotency.max_stored_response_len", 64*1024)
viper.SetDefault("idempotency.cleanup_interval_seconds", 60)
viper.SetDefault("idempotency.cleanup_batch_size", 500)
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true) viper.SetDefault("gateway.log_upstream_error_body", true)
...@@ -882,13 +1155,25 @@ func setDefaults() { ...@@ -882,13 +1155,25 @@ func setDefaults() {
viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
viper.SetDefault("gateway.sora_stream_mode", "force")
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
viper.SetDefault("gateway.sora_media_require_api_key", true)
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认 viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认 viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
...@@ -912,16 +1197,65 @@ func setDefaults() { ...@@ -912,16 +1197,65 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
viper.SetDefault("gateway.usage_record.worker_count", 128)
viper.SetDefault("gateway.usage_record.queue_size", 16384)
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用) // TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10) viper.SetDefault("concurrency.ping_interval", 10)
// Sora 直连配置
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
viper.SetDefault("sora.client.timeout_seconds", 120)
viper.SetDefault("sora.client.max_retries", 3)
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
viper.SetDefault("sora.client.poll_interval_seconds", 2)
viper.SetDefault("sora.client.max_poll_attempts", 600)
viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
viper.SetDefault("sora.storage.type", "local")
viper.SetDefault("sora.storage.local_path", "")
viper.SetDefault("sora.storage.fallback_to_upstream", true)
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
viper.SetDefault("sora.storage.debug", false)
viper.SetDefault("sora.storage.cleanup.enabled", true)
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
// TokenRefresh // TokenRefresh
viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file // Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
...@@ -930,9 +1264,106 @@ func setDefaults() { ...@@ -930,9 +1264,106 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "") viper.SetDefault("gemini.quota.policy", "")
// Security - proxy fallback
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
// Subscription Maintenance (bounded queue + worker pool)
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
jwtSecret := strings.TrimSpace(c.JWT.Secret)
if jwtSecret == "" {
return fmt.Errorf("jwt.secret is required")
}
// NOTE: 按 UTF-8 编码后的字节长度计算。
// 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。
if len([]byte(jwtSecret)) < 32 {
return fmt.Errorf("jwt.secret must be at least 32 bytes")
}
switch c.Log.Level {
case "debug", "info", "warn", "error":
case "":
return fmt.Errorf("log.level is required")
default:
return fmt.Errorf("log.level must be one of: debug/info/warn/error")
}
switch c.Log.Format {
case "json", "console":
case "":
return fmt.Errorf("log.format is required")
default:
return fmt.Errorf("log.format must be one of: json/console")
}
switch c.Log.StacktraceLevel {
case "none", "error", "fatal":
case "":
return fmt.Errorf("log.stacktrace_level is required")
default:
return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal")
}
if !c.Log.Output.ToStdout && !c.Log.Output.ToFile {
return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false")
}
if c.Log.Rotation.MaxSizeMB <= 0 {
return fmt.Errorf("log.rotation.max_size_mb must be positive")
}
if c.Log.Rotation.MaxBackups < 0 {
return fmt.Errorf("log.rotation.max_backups must be non-negative")
}
if c.Log.Rotation.MaxAgeDays < 0 {
return fmt.Errorf("log.rotation.max_age_days must be non-negative")
}
if c.Log.Sampling.Enabled {
if c.Log.Sampling.Initial <= 0 {
return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled")
}
if c.Log.Sampling.Thereafter <= 0 {
return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled")
}
} else {
if c.Log.Sampling.Initial < 0 {
return fmt.Errorf("log.sampling.initial must be non-negative")
}
if c.Log.Sampling.Thereafter < 0 {
return fmt.Errorf("log.sampling.thereafter must be non-negative")
}
}
if c.SubscriptionMaintenance.WorkerCount < 0 {
return fmt.Errorf("subscription_maintenance.worker_count must be non-negative")
}
if c.SubscriptionMaintenance.QueueSize < 0 {
return fmt.Errorf("subscription_maintenance.queue_size must be non-negative")
}
// Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID)
geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret)
if (geminiClientID == "") != (geminiClientSecret == "") {
return fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty")
}
if strings.TrimSpace(c.Server.FrontendURL) != "" {
if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL))
if err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
if u.RawQuery != "" || u.ForceQuery {
return fmt.Errorf("server.frontend_url invalid: must not include query")
}
if u.User != nil {
return fmt.Errorf("server.frontend_url invalid: must not include userinfo")
}
warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL)
}
if c.JWT.ExpireHour <= 0 { if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive") return fmt.Errorf("jwt.expire_hour must be positive")
} }
...@@ -940,20 +1371,20 @@ func (c *Config) Validate() error { ...@@ -940,20 +1371,20 @@ func (c *Config) Validate() error {
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
} }
if c.JWT.ExpireHour > 24 { if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour)
} }
// JWT Refresh Token配置验证 // JWT Refresh Token配置验证
if c.JWT.AccessTokenExpireMinutes <= 0 { if c.JWT.AccessTokenExpireMinutes < 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be positive") return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative")
} }
if c.JWT.AccessTokenExpireMinutes > 720 { if c.JWT.AccessTokenExpireMinutes > 720 {
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes) slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes)
} }
if c.JWT.RefreshTokenExpireDays <= 0 { if c.JWT.RefreshTokenExpireDays <= 0 {
return fmt.Errorf("jwt.refresh_token_expire_days must be positive") return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
} }
if c.JWT.RefreshTokenExpireDays > 90 { if c.JWT.RefreshTokenExpireDays > 90 {
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays) slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays)
} }
if c.JWT.RefreshWindowMinutes < 0 { if c.JWT.RefreshWindowMinutes < 0 {
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
...@@ -1159,9 +1590,116 @@ func (c *Config) Validate() error { ...@@ -1159,9 +1590,116 @@ func (c *Config) Validate() error {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
} }
} }
if c.Idempotency.DefaultTTLSeconds <= 0 {
return fmt.Errorf("idempotency.default_ttl_seconds must be positive")
}
if c.Idempotency.SystemOperationTTLSeconds <= 0 {
return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive")
}
if c.Idempotency.ProcessingTimeoutSeconds <= 0 {
return fmt.Errorf("idempotency.processing_timeout_seconds must be positive")
}
if c.Idempotency.FailedRetryBackoffSeconds <= 0 {
return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive")
}
if c.Idempotency.MaxStoredResponseLen <= 0 {
return fmt.Errorf("idempotency.max_stored_response_len must be positive")
}
if c.Idempotency.CleanupIntervalSeconds <= 0 {
return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive")
}
if c.Idempotency.CleanupBatchSize <= 0 {
return fmt.Errorf("idempotency.cleanup_batch_size must be positive")
}
if c.Gateway.MaxBodySize <= 0 { if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive") return fmt.Errorf("gateway.max_body_size must be positive")
} }
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
}
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
}
if c.Gateway.SoraMaxBodySize < 0 {
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
}
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
}
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
}
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
}
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
switch mode {
case "force", "error":
default:
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
}
}
if c.Sora.Client.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
}
if c.Sora.Client.MaxRetries < 0 {
return fmt.Errorf("sora.client.max_retries must be non-negative")
}
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
}
if c.Sora.Client.PollIntervalSeconds < 0 {
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
}
if c.Sora.Client.MaxPollAttempts < 0 {
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
}
if c.Sora.Client.RecentTaskLimit < 0 {
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax < 0 {
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
}
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
}
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
}
if !c.Sora.Client.CurlCFFISidecar.Enabled {
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
}
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
}
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
}
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
}
if c.Sora.Storage.MaxDownloadBytes < 0 {
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
}
if c.Sora.Storage.Cleanup.Enabled {
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
}
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
}
} else {
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
}
}
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
return fmt.Errorf("sora.storage.type must be 'local'")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation { switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
...@@ -1183,7 +1721,7 @@ func (c *Config) Validate() error { ...@@ -1183,7 +1721,7 @@ func (c *Config) Validate() error {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
} }
if c.Gateway.IdleConnTimeoutSeconds > 180 { if c.Gateway.IdleConnTimeoutSeconds > 180 {
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds) slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds)
} }
if c.Gateway.MaxUpstreamClients <= 0 { if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive") return fmt.Errorf("gateway.max_upstream_clients must be positive")
...@@ -1214,6 +1752,70 @@ func (c *Config) Validate() error { ...@@ -1214,6 +1752,70 @@ func (c *Config) Validate() error {
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
return fmt.Errorf("gateway.max_line_size must be at least 1MB") return fmt.Errorf("gateway.max_line_size must be at least 1MB")
} }
if c.Gateway.UsageRecord.WorkerCount <= 0 {
return fmt.Errorf("gateway.usage_record.worker_count must be positive")
}
if c.Gateway.UsageRecord.QueueSize <= 0 {
return fmt.Errorf("gateway.usage_record.queue_size must be positive")
}
if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive")
}
switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) {
case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
default:
return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s",
UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync)
}
if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100")
}
if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) &&
c.Gateway.UsageRecord.OverflowSamplePercent <= 0 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample")
}
if c.Gateway.UsageRecord.AutoScaleEnabled {
if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers")
}
if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers ||
c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers {
return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers")
}
if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent")
}
if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
}
}
if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive")
}
if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 {
return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
} }
...@@ -1420,6 +2022,6 @@ func warnIfInsecureURL(field, raw string) { ...@@ -1420,6 +2022,6 @@ func warnIfInsecureURL(field, raw string) {
return return
} }
if strings.EqualFold(u.Scheme, "http") { if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field)
} }
} }
...@@ -8,6 +8,25 @@ import ( ...@@ -8,6 +8,25 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
) )
func resetViperWithJWTSecret(t *testing.T) {
t.Helper()
viper.Reset()
t.Setenv("JWT_SECRET", strings.Repeat("x", 32))
}
func TestLoadForBootstrapAllowsMissingJWTSecret(t *testing.T) {
viper.Reset()
t.Setenv("JWT_SECRET", "")
cfg, err := LoadForBootstrap()
if err != nil {
t.Fatalf("LoadForBootstrap() error: %v", err)
}
if cfg.JWT.Secret != "" {
t.Fatalf("LoadForBootstrap() should keep empty jwt.secret during bootstrap")
}
}
func TestNormalizeRunMode(t *testing.T) { func TestNormalizeRunMode(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
...@@ -29,7 +48,7 @@ func TestNormalizeRunMode(t *testing.T) { ...@@ -29,7 +48,7 @@ func TestNormalizeRunMode(t *testing.T) {
} }
func TestLoadDefaultSchedulingConfig(t *testing.T) { func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -56,8 +75,44 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { ...@@ -56,8 +75,44 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
} }
} }
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Idempotency.ObserveOnly {
t.Fatalf("Idempotency.ObserveOnly = false, want true")
}
if cfg.Idempotency.DefaultTTLSeconds != 86400 {
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds)
}
if cfg.Idempotency.SystemOperationTTLSeconds != 3600 {
t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds)
}
}
func TestLoadIdempotencyConfigFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false")
t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Idempotency.ObserveOnly {
t.Fatalf("Idempotency.ObserveOnly = true, want false")
}
if cfg.Idempotency.DefaultTTLSeconds != 600 {
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) { func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load() cfg, err := Load()
...@@ -71,7 +126,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { ...@@ -71,7 +126,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
} }
func TestLoadDefaultSecurityToggles(t *testing.T) { func TestLoadDefaultSecurityToggles(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -87,13 +142,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { ...@@ -87,13 +142,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
if !cfg.Security.URLAllowlist.AllowPrivateHosts { if !cfg.Security.URLAllowlist.AllowPrivateHosts {
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true") t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
} }
if cfg.Security.ResponseHeaders.Enabled { if !cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false") t.Fatalf("ResponseHeaders.Enabled = false, want true")
}
}
func TestLoadDefaultServerMode(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Server.Mode != "release" {
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
}
}
func TestLoadDefaultJWTAccessTokenExpireMinutes(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.JWT.ExpireHour != 24 {
t.Fatalf("JWT.ExpireHour = %d, want 24", cfg.JWT.ExpireHour)
}
if cfg.JWT.AccessTokenExpireMinutes != 0 {
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 0", cfg.JWT.AccessTokenExpireMinutes)
}
}
func TestLoadJWTAccessTokenExpireMinutesFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "90")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.JWT.AccessTokenExpireMinutes != 90 {
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 90", cfg.JWT.AccessTokenExpireMinutes)
}
}
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Database.SSLMode != "prefer" {
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
} }
} }
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -118,7 +229,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { ...@@ -118,7 +229,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
} }
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -143,7 +254,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { ...@@ -143,7 +254,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
} }
func TestLoadDefaultDashboardCacheConfig(t *testing.T) { func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -168,7 +279,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) { ...@@ -168,7 +279,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
} }
func TestValidateDashboardCacheConfigEnabled(t *testing.T) { func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -188,7 +299,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) { ...@@ -188,7 +299,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
} }
func TestValidateDashboardCacheConfigDisabled(t *testing.T) { func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -207,7 +318,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) { ...@@ -207,7 +318,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
} }
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -244,7 +355,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { ...@@ -244,7 +355,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
} }
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -263,7 +374,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { ...@@ -263,7 +374,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
} }
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -282,7 +393,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { ...@@ -282,7 +393,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
} }
func TestLoadDefaultUsageCleanupConfig(t *testing.T) { func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -307,7 +418,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) { ...@@ -307,7 +418,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
} }
func TestValidateUsageCleanupConfigEnabled(t *testing.T) { func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -326,7 +437,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) { ...@@ -326,7 +437,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
} }
func TestValidateUsageCleanupConfigDisabled(t *testing.T) { func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -424,6 +535,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { ...@@ -424,6 +535,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
} }
} }
// TestValidateServerFrontendURL verifies that server.frontend_url accepts
// absolute http(s) URLs (optionally with a path) and rejects URLs carrying a
// query string or userinfo, as well as relative paths.
func TestValidateServerFrontendURL(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}

	// mustPass asserts Validate() accepts the given frontend URL.
	mustPass := func(url, label string) {
		t.Helper()
		cfg.Server.FrontendURL = url
		if err := cfg.Validate(); err != nil {
			t.Fatalf("Validate() %s error: %v", label, err)
		}
	}
	// mustReject asserts Validate() rejects the given frontend URL.
	mustReject := func(url, desc string) {
		t.Helper()
		cfg.Server.FrontendURL = url
		if err := cfg.Validate(); err == nil {
			t.Fatalf("Validate() should reject %s", desc)
		}
	}

	mustPass("https://example.com", "frontend_url valid")
	mustPass("https://example.com/path", "frontend_url with path valid")
	mustReject("https://example.com?utm=1", "server.frontend_url with query")
	mustReject("https://user:pass@example.com", "server.frontend_url with userinfo")
	mustReject("/relative", "relative server.frontend_url")
}
func TestValidateFrontendRedirectURL(t *testing.T) { func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
...@@ -445,6 +590,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) { ...@@ -445,6 +590,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) {
func TestWarnIfInsecureURL(t *testing.T) { func TestWarnIfInsecureURL(t *testing.T) {
warnIfInsecureURL("test", "http://example.com") warnIfInsecureURL("test", "http://example.com")
warnIfInsecureURL("test", "bad://url") warnIfInsecureURL("test", "bad://url")
warnIfInsecureURL("test", "://invalid")
} }
func TestGenerateJWTSecretDefaultLength(t *testing.T) { func TestGenerateJWTSecretDefaultLength(t *testing.T) {
...@@ -458,7 +604,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) { ...@@ -458,7 +604,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) {
} }
func TestValidateOpsCleanupScheduleRequired(t *testing.T) { func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -476,7 +622,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) { ...@@ -476,7 +622,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
} }
func TestValidateConcurrencyPingInterval(t *testing.T) { func TestValidateConcurrencyPingInterval(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -493,14 +639,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) { ...@@ -493,14 +639,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) {
} }
func TestProvideConfig(t *testing.T) { func TestProvideConfig(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
if _, err := ProvideConfig(); err != nil { if _, err := ProvideConfig(); err != nil {
t.Fatalf("ProvideConfig() error: %v", err) t.Fatalf("ProvideConfig() error: %v", err)
} }
} }
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
...@@ -544,6 +690,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) { ...@@ -544,6 +690,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) {
} }
} }
// TestDatabaseDSNWithTimezone_WithPassword checks that a populated password
// appears in the generated DSN and that the requested timezone is applied.
func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) {
	cfg := &DatabaseConfig{
		Host:     "localhost",
		Port:     5432,
		User:     "u",
		Password: "p",
		DBName:   "db",
		SSLMode:  "prefer",
	}
	dsn := cfg.DSNWithTimezone("UTC")
	for _, want := range []string{"password=p", "TimeZone=UTC"} {
		if strings.Contains(dsn, want) {
			continue
		}
		switch want {
		case "password=p":
			t.Fatalf("DSNWithTimezone should include password: %q", dsn)
		default:
			t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", dsn)
		}
	}
}
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) { func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://"); err == nil { if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host") t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
...@@ -566,10 +730,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) { ...@@ -566,10 +730,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL("secure", "https://example.com") warnIfInsecureURL("secure", "https://example.com")
} }
// TestValidateJWTSecret_UTF8Bytes verifies that jwt.secret length validation
// counts UTF-8 bytes rather than characters. The original test only exercised
// ASCII input (where bytes == characters), so it could not actually detect a
// character-based implementation; the multi-byte cases below can.
func TestValidateJWTSecret_UTF8Bytes(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	// 31 bytes (< 32) even though it's 31 characters.
	cfg.JWT.Secret = strings.Repeat("a", 31)
	err = cfg.Validate()
	if err == nil {
		t.Fatalf("Validate() should reject 31-byte secret")
	}
	if !strings.Contains(err.Error(), "at least 32 bytes") {
		t.Fatalf("Validate() error = %v", err)
	}
	// 32 bytes OK.
	cfg.JWT.Secret = strings.Repeat("a", 32)
	err = cfg.Validate()
	if err != nil {
		t.Fatalf("Validate() should accept 32-byte secret: %v", err)
	}
	// 10 three-byte runes = 30 bytes (< 32): must be rejected even though a
	// character-counting implementation would also reject; this pairs with the
	// accept case below to pin byte semantics.
	cfg.JWT.Secret = strings.Repeat("界", 10)
	if err := cfg.Validate(); err == nil {
		t.Fatalf("Validate() should reject 30-byte multi-byte secret")
	}
	// 11 three-byte runes = 33 bytes (>= 32) but only 11 characters: a
	// byte-counting implementation accepts, a character-counting one rejects.
	cfg.JWT.Secret = strings.Repeat("界", 11)
	if err := cfg.Validate(); err != nil {
		t.Fatalf("Validate() should accept 33-byte multi-byte secret: %v", err)
	}
}
func TestValidateConfigErrors(t *testing.T) { func TestValidateConfigErrors(t *testing.T) {
buildValid := func(t *testing.T) *Config { buildValid := func(t *testing.T) *Config {
t.Helper() t.Helper()
viper.Reset() resetViperWithJWTSecret(t)
cfg, err := Load() cfg, err := Load()
if err != nil { if err != nil {
t.Fatalf("Load() error: %v", err) t.Fatalf("Load() error: %v", err)
...@@ -582,6 +771,26 @@ func TestValidateConfigErrors(t *testing.T) { ...@@ -582,6 +771,26 @@ func TestValidateConfigErrors(t *testing.T) {
mutate func(*Config) mutate func(*Config)
wantErr string wantErr string
}{ }{
{
name: "jwt secret required",
mutate: func(c *Config) { c.JWT.Secret = "" },
wantErr: "jwt.secret is required",
},
{
name: "jwt secret min bytes",
mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) },
wantErr: "jwt.secret must be at least 32 bytes",
},
{
name: "subscription maintenance worker_count non-negative",
mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 },
wantErr: "subscription_maintenance.worker_count",
},
{
name: "subscription maintenance queue_size non-negative",
mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 },
wantErr: "subscription_maintenance.queue_size",
},
{ {
name: "jwt expire hour positive", name: "jwt expire hour positive",
mutate: func(c *Config) { c.JWT.ExpireHour = 0 }, mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
...@@ -592,6 +801,11 @@ func TestValidateConfigErrors(t *testing.T) { ...@@ -592,6 +801,11 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.JWT.ExpireHour = 200 }, mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
wantErr: "jwt.expire_hour must be <= 168", wantErr: "jwt.expire_hour must be <= 168",
}, },
{
name: "jwt access token expire minutes non-negative",
mutate: func(c *Config) { c.JWT.AccessTokenExpireMinutes = -1 },
wantErr: "jwt.access_token_expire_minutes must be non-negative",
},
{ {
name: "csp policy required", name: "csp policy required",
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" }, mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
...@@ -799,6 +1013,84 @@ func TestValidateConfigErrors(t *testing.T) { ...@@ -799,6 +1013,84 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 }, mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative", wantErr: "gateway.max_line_size must be non-negative",
}, },
{
name: "gateway usage record worker count",
mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 },
wantErr: "gateway.usage_record.worker_count",
},
{
name: "gateway usage record queue size",
mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 },
wantErr: "gateway.usage_record.queue_size",
},
{
name: "gateway usage record timeout",
mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 },
wantErr: "gateway.usage_record.task_timeout_seconds",
},
{
name: "gateway usage record overflow policy",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" },
wantErr: "gateway.usage_record.overflow_policy",
},
{
name: "gateway usage record sample percent range",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 },
wantErr: "gateway.usage_record.overflow_sample_percent",
},
{
name: "gateway usage record sample percent required for sample policy",
mutate: func(c *Config) {
c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample
c.Gateway.UsageRecord.OverflowSamplePercent = 0
},
wantErr: "gateway.usage_record.overflow_sample_percent must be positive",
},
{
name: "gateway usage record auto scale max gte min",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 256
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128
},
wantErr: "gateway.usage_record.auto_scale_max_workers",
},
{
name: "gateway usage record worker in auto scale range",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 200
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300
c.Gateway.UsageRecord.WorkerCount = 128
},
wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers",
},
{
name: "gateway usage record auto scale queue thresholds order",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50
c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50
},
wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less",
},
{
name: "gateway usage record auto scale up step",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 },
wantErr: "gateway.usage_record.auto_scale_up_step",
},
{
name: "gateway usage record auto scale interval",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
},
{
name: "gateway user group rate cache ttl",
mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 },
wantErr: "gateway.user_group_rate_cache_ttl_seconds",
},
{
name: "gateway models list cache ttl range",
mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 },
wantErr: "gateway.models_list_cache_ttl_seconds",
},
{ {
name: "gateway scheduling sticky waiting", name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
...@@ -822,6 +1114,37 @@ func TestValidateConfigErrors(t *testing.T) { ...@@ -822,6 +1114,37 @@ func TestValidateConfigErrors(t *testing.T) {
}, },
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds", wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
}, },
{
name: "log level invalid",
mutate: func(c *Config) { c.Log.Level = "trace" },
wantErr: "log.level",
},
{
name: "log format invalid",
mutate: func(c *Config) { c.Log.Format = "plain" },
wantErr: "log.format",
},
{
name: "log output disabled",
mutate: func(c *Config) {
c.Log.Output.ToStdout = false
c.Log.Output.ToFile = false
},
wantErr: "log.output.to_stdout and log.output.to_file cannot both be false",
},
{
name: "log rotation size",
mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 },
wantErr: "log.rotation.max_size_mb",
},
{
name: "log sampling enabled invalid",
mutate: func(c *Config) {
c.Log.Sampling.Enabled = true
c.Log.Sampling.Initial = 0
},
wantErr: "log.sampling.initial",
},
{ {
name: "ops metrics collector ttl", name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 }, mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
...@@ -850,3 +1173,234 @@ func TestValidateConfigErrors(t *testing.T) { ...@@ -850,3 +1173,234 @@ func TestValidateConfigErrors(t *testing.T) {
}) })
} }
} }
// TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields checks that the
// usage-record auto-scaling knobs are not validated while auto-scaling is
// switched off, even when they hold otherwise-invalid values.
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	ur := &cfg.Gateway.UsageRecord
	ur.AutoScaleEnabled = false
	ur.WorkerCount = 64
	// With auto-scaling disabled these fields should be ignored and must not
	// cause validation to fail.
	ur.AutoScaleMinWorkers = 0
	ur.AutoScaleMaxWorkers = 0
	ur.AutoScaleUpQueuePercent = 0
	ur.AutoScaleDownQueuePercent = 100
	ur.AutoScaleUpStep = 0
	ur.AutoScaleDownStep = 0
	ur.AutoScaleCheckIntervalSeconds = 0
	ur.AutoScaleCooldownSeconds = -1
	if err := cfg.Validate(); err != nil {
		t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err)
	}
}
// TestValidateConfig_LogRequiredAndRotationBounds covers the required log
// fields and the non-negative bounds on rotation and sampling settings.
func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
	resetViperWithJWTSecret(t)
	type testCase struct {
		name    string
		mutate  func(*Config)
		wantErr string
	}
	tests := []testCase{
		{"log level required", func(c *Config) { c.Log.Level = "" }, "log.level is required"},
		{"log format required", func(c *Config) { c.Log.Format = "" }, "log.format is required"},
		{"log stacktrace required", func(c *Config) { c.Log.StacktraceLevel = "" }, "log.stacktrace_level is required"},
		{"log max backups non-negative", func(c *Config) { c.Log.Rotation.MaxBackups = -1 }, "log.rotation.max_backups must be non-negative"},
		{"log max age non-negative", func(c *Config) { c.Log.Rotation.MaxAgeDays = -1 }, "log.rotation.max_age_days must be non-negative"},
		{"sampling thereafter non-negative when disabled", func(c *Config) {
			c.Log.Sampling.Enabled = false
			c.Log.Sampling.Thereafter = -1
		}, "log.sampling.thereafter must be non-negative"},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			cfg, err := Load()
			if err != nil {
				t.Fatalf("Load() error: %v", err)
			}
			tc.mutate(cfg)
			if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), tc.wantErr) {
				t.Fatalf("Validate() error = %v, want %q", err, tc.wantErr)
			}
		})
	}
}
// TestSoraCurlCFFISidecarDefaults asserts the out-of-the-box Sora curl_cffi
// sidecar configuration produced by Load().
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	client := cfg.Sora.Client
	sidecar := client.CurlCFFISidecar
	if !sidecar.Enabled {
		t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
	}
	if client.CloudflareChallengeCooldownSeconds <= 0 {
		t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
	}
	if sidecar.BaseURL == "" {
		t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
	}
	if sidecar.Impersonate == "" {
		t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
	}
	if !sidecar.SessionReuseEnabled {
		t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
	}
	if sidecar.SessionTTLSeconds <= 0 {
		t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
	}
}
// TestValidateSoraCurlCFFISidecarRequired checks that disabling the Sora
// curl_cffi sidecar is rejected by validation.
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, loadErr := Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}
	cfg.Sora.Client.CurlCFFISidecar.Enabled = false
	if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
		t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
	}
}
// TestValidateSoraCurlCFFISidecarBaseURLRequired checks that a blank (all
// whitespace) sidecar base URL is rejected by validation.
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, loadErr := Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}
	cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
	if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
		t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
	}
}
// TestValidateSoraCurlCFFISidecarSessionTTLNonNegative checks that a negative
// sidecar session TTL is rejected by validation.
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, loadErr := Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}
	cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
	if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
		t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
	}
}
// TestValidateSoraCloudflareChallengeCooldownNonNegative checks that a
// negative Cloudflare challenge cooldown is rejected by validation.
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, loadErr := Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}
	cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
	if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
		t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
	}
}
// TestLoad_DefaultGatewayUsageRecordConfig pins every default value of the
// gateway usage-record pipeline, including its auto-scaling knobs.
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	ur := cfg.Gateway.UsageRecord
	if ur.WorkerCount != 128 {
		t.Fatalf("worker_count = %d, want 128", ur.WorkerCount)
	}
	if ur.QueueSize != 16384 {
		t.Fatalf("queue_size = %d, want 16384", ur.QueueSize)
	}
	if ur.TaskTimeoutSeconds != 5 {
		t.Fatalf("task_timeout_seconds = %d, want 5", ur.TaskTimeoutSeconds)
	}
	if ur.OverflowPolicy != UsageRecordOverflowPolicySample {
		t.Fatalf("overflow_policy = %s, want %s", ur.OverflowPolicy, UsageRecordOverflowPolicySample)
	}
	if ur.OverflowSamplePercent != 10 {
		t.Fatalf("overflow_sample_percent = %d, want 10", ur.OverflowSamplePercent)
	}
	if !ur.AutoScaleEnabled {
		t.Fatalf("auto_scale_enabled = false, want true")
	}
	if ur.AutoScaleMinWorkers != 128 {
		t.Fatalf("auto_scale_min_workers = %d, want 128", ur.AutoScaleMinWorkers)
	}
	if ur.AutoScaleMaxWorkers != 512 {
		t.Fatalf("auto_scale_max_workers = %d, want 512", ur.AutoScaleMaxWorkers)
	}
	if ur.AutoScaleUpQueuePercent != 70 {
		t.Fatalf("auto_scale_up_queue_percent = %d, want 70", ur.AutoScaleUpQueuePercent)
	}
	if ur.AutoScaleDownQueuePercent != 15 {
		t.Fatalf("auto_scale_down_queue_percent = %d, want 15", ur.AutoScaleDownQueuePercent)
	}
	if ur.AutoScaleUpStep != 32 {
		t.Fatalf("auto_scale_up_step = %d, want 32", ur.AutoScaleUpStep)
	}
	if ur.AutoScaleDownStep != 16 {
		t.Fatalf("auto_scale_down_step = %d, want 16", ur.AutoScaleDownStep)
	}
	if ur.AutoScaleCheckIntervalSeconds != 3 {
		t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", ur.AutoScaleCheckIntervalSeconds)
	}
	if ur.AutoScaleCooldownSeconds != 10 {
		t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", ur.AutoScaleCooldownSeconds)
	}
}
...@@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet( ...@@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet(
// ProvideConfig 提供应用配置 // ProvideConfig 提供应用配置
func ProvideConfig() (*Config, error) { func ProvideConfig() (*Config, error) {
return Load() return LoadForBootstrap()
} }
...@@ -22,6 +22,7 @@ const ( ...@@ -22,6 +22,7 @@ const (
PlatformOpenAI = "openai" PlatformOpenAI = "openai"
PlatformGemini = "gemini" PlatformGemini = "gemini"
PlatformAntigravity = "antigravity" PlatformAntigravity = "antigravity"
PlatformSora = "sora"
) )
// Account type constants // Account type constants
......
...@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) { ...@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
return return
} }
dataPayload := req.Data if err := validateDataHeader(req.Data); err != nil {
if err := validateDataHeader(dataPayload); err != nil {
response.BadRequest(c, err.Error()) response.BadRequest(c, err.Error())
return return
} }
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
return h.importData(ctx, req)
})
}
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
skipDefaultGroupBind := true skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil { if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind skipDefaultGroupBind = *req.SkipDefaultGroupBind
} }
dataPayload := req.Data
result := DataImportResult{} result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
existingProxies, err := h.listAllProxies(ctx)
if err != nil { if err != nil {
response.ErrorFrom(c, err) return result, err
return
} }
proxyKeyToID := make(map[string]int64, len(existingProxies)) proxyKeyToID := make(map[string]int64, len(existingProxies))
...@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) { ...@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
proxyKeyToID[key] = existingID proxyKeyToID[key] = existingID
result.ProxyReused++ result.ProxyReused++
if normalizedStatus != "" { if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus { if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{ _, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
Status: normalizedStatus, Status: normalizedStatus,
}) })
} }
...@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { ...@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
continue continue
} }
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name), Name: defaultProxyName(item.Name),
Protocol: item.Protocol, Protocol: item.Protocol,
Host: item.Host, Host: item.Host,
...@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) { ...@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
Username: item.Username, Username: item.Username,
Password: item.Password, Password: item.Password,
}) })
if err != nil { if createErr != nil {
result.ProxyFailed++ result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{ result.Errors = append(result.Errors, DataImportError{
Kind: "proxy", Kind: "proxy",
Name: item.Name, Name: item.Name,
ProxyKey: key, ProxyKey: key,
Message: err.Error(), Message: createErr.Error(),
}) })
continue continue
} }
...@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { ...@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.ProxyCreated++ result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status { if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{ _, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
Status: normalizedStatus, Status: normalizedStatus,
}) })
} }
...@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { ...@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
SkipDefaultGroupBind: skipDefaultGroupBind, SkipDefaultGroupBind: skipDefaultGroupBind,
} }
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil { if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
result.AccountFailed++ result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{ result.Errors = append(result.Errors, DataImportError{
Kind: "account", Kind: "account",
...@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { ...@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.AccountCreated++ result.AccountCreated++
} }
response.Success(c, result) return result, nil
} }
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) { func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
......
...@@ -2,7 +2,13 @@ ...@@ -2,7 +2,13 @@
package admin package admin
import ( import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors" "errors"
"fmt"
"net/http"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
...@@ -142,6 +148,44 @@ type AccountWithConcurrency struct { ...@@ -142,6 +148,44 @@ type AccountWithConcurrency struct {
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数 ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
} }
// buildAccountResponseWithRuntime converts a service account into its API DTO
// and enriches it with runtime metrics: live concurrency, current window cost,
// and active session count. Every enrichment is best-effort — a nil service
// dependency or a lookup error leaves the corresponding field at its zero
// value, so a usable response is always returned.
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
	item := AccountWithConcurrency{
		// NOTE(review): AccountFromService is invoked before the nil check
		// below — presumably it tolerates a nil account; confirm in dto.
		Account:            dto.AccountFromService(account),
		CurrentConcurrency: 0,
	}
	if account == nil {
		return item
	}
	// Live concurrency, fetched via the batch API for this single account.
	if h.concurrencyService != nil {
		if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil {
			item.CurrentConcurrency = counts[account.ID]
		}
	}
	// Window cost and active sessions apply only to Anthropic OAuth /
	// setup-token accounts, and only when the relevant limit is configured.
	if account.IsAnthropicOAuthOrSetupToken() {
		if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 {
			startTime := account.GetCurrentWindowStartTime()
			if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil {
				// Copy into a local so the pointer owns its own value.
				cost := stats.StandardCost
				item.CurrentWindowCost = &cost
			}
		}
		if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 {
			idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
			idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout}
			if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil {
				if count, ok := sessions[account.ID]; ok {
					item.ActiveSessions = &count
				}
			}
		}
	}
	return item
}
// List handles listing all accounts with pagination // List handles listing all accounts with pagination
// GET /api/v1/admin/accounts // GET /api/v1/admin/accounts
func (h *AccountHandler) List(c *gin.Context) { func (h *AccountHandler) List(c *gin.Context) {
...@@ -262,9 +306,71 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -262,9 +306,71 @@ func (h *AccountHandler) List(c *gin.Context) {
result[i] = item result[i] = item
} }
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
if etag != "" {
c.Header("ETag", etag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) {
c.Status(http.StatusNotModified)
return
}
}
response.Paginated(c, result, total, page, pageSize) response.Paginated(c, result, total, page, pageSize)
} }
// buildAccountsListETag derives a strong ETag for one page of the accounts
// listing. The tag is the SHA-256 of a JSON snapshot of the page contents plus
// every query parameter that shaped it, so any change in data or filters
// produces a different tag. Returns "" when the snapshot cannot be marshaled.
func buildAccountsListETag(
	items []AccountWithConcurrency,
	total int64,
	page, pageSize int,
	platform, accountType, status, search string,
) string {
	// Field order and tags must stay stable: they determine the JSON bytes
	// and therefore the ETag value.
	snapshot := struct {
		Total       int64                    `json:"total"`
		Page        int                      `json:"page"`
		PageSize    int                      `json:"page_size"`
		Platform    string                   `json:"platform"`
		AccountType string                   `json:"type"`
		Status      string                   `json:"status"`
		Search      string                   `json:"search"`
		Items       []AccountWithConcurrency `json:"items"`
	}{total, page, pageSize, platform, accountType, status, search, items}
	encoded, marshalErr := json.Marshal(snapshot)
	if marshalErr != nil {
		return ""
	}
	digest := sha256.Sum256(encoded)
	return `"` + hex.EncodeToString(digest[:]) + `"`
}
// ifNoneMatchMatched reports whether an If-None-Match request header matches
// the given ETag. The header may hold a comma-separated list; "*" matches any
// tag, and a weak validator prefix ("W/") on a client tag is stripped before
// comparison. An empty header or empty ETag never matches.
func ifNoneMatchMatched(ifNoneMatch, etag string) bool {
	if ifNoneMatch == "" || etag == "" {
		return false
	}
	for _, raw := range strings.Split(ifNoneMatch, ",") {
		switch c := strings.TrimSpace(raw); {
		case c == "*", c == etag:
			return true
		case strings.HasPrefix(c, "W/"):
			if c[len("W/"):] == etag {
				return true
			}
		}
	}
	return false
}
// GetByID handles getting an account by ID // GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id // GET /api/v1/admin/accounts/:id
func (h *AccountHandler) GetByID(c *gin.Context) { func (h *AccountHandler) GetByID(c *gin.Context) {
...@@ -280,7 +386,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) { ...@@ -280,7 +386,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, dto.AccountFromService(account)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// Create handles creating a new account // Create handles creating a new account
...@@ -299,21 +405,27 @@ func (h *AccountHandler) Create(c *gin.Context) { ...@@ -299,21 +405,27 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 确定是否跳过混合渠道检查 // 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Name: req.Name, account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Notes: req.Notes, Name: req.Name,
Platform: req.Platform, Notes: req.Notes,
Type: req.Type, Platform: req.Platform,
Credentials: req.Credentials, Type: req.Type,
Extra: req.Extra, Credentials: req.Credentials,
ProxyID: req.ProxyID, Extra: req.Extra,
Concurrency: req.Concurrency, ProxyID: req.ProxyID,
Priority: req.Priority, Concurrency: req.Concurrency,
RateMultiplier: req.RateMultiplier, Priority: req.Priority,
GroupIDs: req.GroupIDs, RateMultiplier: req.RateMultiplier,
ExpiresAt: req.ExpiresAt, GroupIDs: req.GroupIDs,
AutoPauseOnExpired: req.AutoPauseOnExpired, ExpiresAt: req.ExpiresAt,
SkipMixedChannelCheck: skipCheck, AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if execErr != nil {
return nil, execErr
}
return h.buildAccountResponseWithRuntime(ctx, account), nil
}) })
if err != nil { if err != nil {
// 检查是否为混合渠道错误 // 检查是否为混合渠道错误
...@@ -334,11 +446,17 @@ func (h *AccountHandler) Create(c *gin.Context) { ...@@ -334,11 +446,17 @@ func (h *AccountHandler) Create(c *gin.Context) {
return return
} }
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, dto.AccountFromService(account)) if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
} }
// Update handles updating an account // Update handles updating an account
...@@ -402,7 +520,7 @@ func (h *AccountHandler) Update(c *gin.Context) { ...@@ -402,7 +520,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, dto.AccountFromService(account)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// Delete handles deleting an account // Delete handles deleting an account
...@@ -660,7 +778,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { ...@@ -660,7 +778,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
} }
} }
response.Success(c, dto.AccountFromService(updatedAccount)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
} }
// GetStats handles getting account statistics // GetStats handles getting account statistics
...@@ -718,7 +836,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) { ...@@ -718,7 +836,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
} }
} }
response.Success(c, dto.AccountFromService(account)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// BatchCreate handles batch creating accounts // BatchCreate handles batch creating accounts
...@@ -732,61 +850,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { ...@@ -732,61 +850,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return return
} }
ctx := c.Request.Context() executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
success := 0 success := 0
failed := 0 failed := 0
results := make([]gin.H, 0, len(req.Accounts)) results := make([]gin.H, 0, len(req.Accounts))
for _, item := range req.Accounts { for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 { if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++ failed++
results = append(results, gin.H{ results = append(results, gin.H{
"name": item.Name, "name": item.Name,
"success": false, "success": false,
"error": "rate_multiplier must be >= 0", "error": "rate_multiplier must be >= 0",
}) })
continue continue
} }
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name, Name: item.Name,
Notes: item.Notes, Notes: item.Notes,
Platform: item.Platform, Platform: item.Platform,
Type: item.Type, Type: item.Type,
Credentials: item.Credentials, Credentials: item.Credentials,
Extra: item.Extra, Extra: item.Extra,
ProxyID: item.ProxyID, ProxyID: item.ProxyID,
Concurrency: item.Concurrency, Concurrency: item.Concurrency,
Priority: item.Priority, Priority: item.Priority,
RateMultiplier: item.RateMultiplier, RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs, GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt, ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired, AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck, SkipMixedChannelCheck: skipCheck,
}) })
if err != nil { if err != nil {
failed++ failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{ results = append(results, gin.H{
"name": item.Name, "name": item.Name,
"success": false, "id": account.ID,
"error": err.Error(), "success": true,
}) })
continue
} }
success++
results = append(results, gin.H{
"name": item.Name,
"id": account.ID,
"success": true,
})
}
response.Success(c, gin.H{ return gin.H{
"success": success, "success": success,
"failed": failed, "failed": failed,
"results": results, "results": results,
}, nil
}) })
} }
...@@ -824,57 +943,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { ...@@ -824,57 +943,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
} }
ctx := c.Request.Context() ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs { for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID) account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil { if err != nil {
failed++ response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
results = append(results, gin.H{ return
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
} }
// Update credentials field
if account.Credentials == nil { if account.Credentials == nil {
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials[req.Field] = req.Value account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account // 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
updateInput := &service.UpdateAccountInput{ success := 0
Credentials: account.Credentials, failed := 0
} successIDs := make([]int64, 0, len(updates))
failedIDs := make([]int64, 0, len(updates))
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) results := make([]gin.H, 0, len(updates))
if err != nil { for _, u := range updates {
updateInput := &service.UpdateAccountInput{Credentials: u.Credentials}
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
failed++ failed++
failedIDs = append(failedIDs, u.ID)
results = append(results, gin.H{ results = append(results, gin.H{
"account_id": accountID, "account_id": u.ID,
"success": false, "success": false,
"error": err.Error(), "error": err.Error(),
}) })
continue continue
} }
success++ success++
successIDs = append(successIDs, u.ID)
results = append(results, gin.H{ results = append(results, gin.H{
"account_id": accountID, "account_id": u.ID,
"success": true, "success": true,
}) })
} }
response.Success(c, gin.H{ response.Success(c, gin.H{
"success": success, "success": success,
"failed": failed, "failed": failed,
"results": results, "success_ids": successIDs,
"failed_ids": failedIDs,
"results": results,
}) })
} }
...@@ -1109,7 +1229,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { ...@@ -1109,7 +1229,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
return return
} }
response.Success(c, gin.H{"message": "Rate limit cleared successfully"}) account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// GetTempUnschedulable handles getting temporary unschedulable status // GetTempUnschedulable handles getting temporary unschedulable status
...@@ -1199,7 +1325,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { ...@@ -1199,7 +1325,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return return
} }
response.Success(c, dto.AccountFromService(account)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// GetAvailableModels handles getting available models for an account // GetAvailableModels handles getting available models for an account
...@@ -1325,6 +1451,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { ...@@ -1325,6 +1451,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return return
} }
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
// Handle Claude/Anthropic accounts // Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models // For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() { if account.IsOAuth() {
......
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
gin.SetMode(gin.TestMode)
adminSvc := newStubAdminService()
handler := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router := gin.New()
router.POST("/api/v1/admin/accounts", handler.Create)
body := map[string]any{
"name": "anthropic-key-1",
"platform": "anthropic",
"type": "apikey",
"credentials": map[string]any{
"api_key": "sk-ant-xxx",
"base_url": "https://api.anthropic.com",
},
"extra": map[string]any{
"anthropic_passthrough": true,
},
"concurrency": 1,
"priority": 1,
}
raw, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdAccounts, 1)
created := adminSvc.createdAccounts[0]
require.Equal(t, "anthropic", created.Platform)
require.Equal(t, "apikey", created.Type)
require.NotNil(t, created.Extra)
require.Equal(t, true, created.Extra["anthropic_passthrough"])
}
...@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { ...@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
...@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) { ...@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
......
...@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) { ...@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
require.False(t, ok) require.False(t, ok)
} }
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
tests := []struct {
input string
want time.Duration
ok bool
}{
{input: "30m", want: 30 * time.Minute, ok: true},
{input: "1h", want: time.Hour, ok: true},
{input: "1d", want: 24 * time.Hour, ok: true},
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
{input: "7d", want: 0, ok: false},
}
for _, tt := range tests {
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
require.Equal(t, tt.want, got, "input=%s", tt.input)
}
}
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
before := time.Now().UTC()
filter, err := parseOpsOpenAITokenStatsFilter(c)
after := time.Now().UTC()
require.NoError(t, err)
require.NotNil(t, filter)
require.Equal(t, "30d", filter.TimeRange)
require.Equal(t, 1, filter.Page)
require.Equal(t, 20, filter.PageSize)
require.Equal(t, 0, filter.TopN)
require.Nil(t, filter.GroupID)
require.Equal(t, "", filter.Platform)
require.True(t, filter.StartTime.Before(filter.EndTime))
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
}
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(
http.MethodGet,
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
nil,
)
filter, err := parseOpsOpenAITokenStatsFilter(c)
require.NoError(t, err)
require.Equal(t, "1h", filter.TimeRange)
require.Equal(t, "openai", filter.Platform)
require.NotNil(t, filter.GroupID)
require.Equal(t, int64(12), *filter.GroupID)
require.Equal(t, 50, filter.TopN)
require.Equal(t, 0, filter.Page)
require.Equal(t, 0, filter.PageSize)
}
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
tests := []string{
"/?time_range=7d",
"/?group_id=0",
"/?group_id=abc",
"/?top_n=0",
"/?top_n=101",
"/?top_n=10&page=1",
"/?top_n=10&page_size=20",
"/?page=0",
"/?page_size=0",
"/?page_size=101",
}
gin.SetMode(gin.TestMode)
for _, rawURL := range tests {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
_, err := parseOpsOpenAITokenStatsFilter(c)
require.Error(t, err, "url=%s", rawURL)
}
}
func TestParseOpsTimeRange(t *testing.T) { func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
w := httptest.NewRecorder() w := httptest.NewRecorder()
......
...@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr ...@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
} }
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
return &service.ProxyQualityCheckResult{
ProxyID: id,
Score: 95,
Grade: "A",
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
PassedCount: 5,
WarnCount: 0,
FailedCount: 0,
ChallengeCount: 0,
CheckedAt: time.Now().Unix(),
Items: []service.ProxyQualityCheckItem{
{Target: "base_connectivity", Status: "pass", Message: "ok"},
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
{Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil return s.redeems, int64(len(s.redeems)), nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment