Commit 13262a56 authored by yangjianbo's avatar yangjianbo
Browse files

feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题



新增功能:
- 新增 Sora 账号管理和 OAuth 认证
- 新增 Sora 视频/图片生成 API 网关
- 新增 Sora 任务调度和缓存机制
- 新增 Sora 使用统计和计费支持
- 前端增加 Sora 平台配置界面

安全修复(代码审核):
- [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击
- [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽
- [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置

BUG 修复(代码审核):
- [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏
- [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏

性能优化(代码审核):
- [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销

技术细节:
- 使用 io.LimitReader 限制所有外部输入的大小
- 添加 urlvalidator 验证防止 SSRF 攻击
- 使用 sync.Map 实现线程安全的包级缓存
- 优化并发槽位管理,添加 releaseAll 模式防止泄漏

影响范围:
- 后端:新增 Sora 相关数据模型、服务、网关和管理接口
- 前端:新增 Sora 平台配置、账号管理和监控界面
- 配置:新增 Sora 相关配置项和环境变量
Co-Authored-By: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
parent bece1b52
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/sorausagestat"
)
// SoraUsageStatQuery is the builder for querying SoraUsageStat entities.
type SoraUsageStatQuery struct {
config
ctx *QueryContext
order []sorausagestat.OrderOption
inters []Interceptor
predicates []predicate.SoraUsageStat
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the SoraUsageStatQuery builder.
func (_q *SoraUsageStatQuery) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *SoraUsageStatQuery) Limit(limit int) *SoraUsageStatQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *SoraUsageStatQuery) Offset(offset int) *SoraUsageStatQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *SoraUsageStatQuery) Unique(unique bool) *SoraUsageStatQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *SoraUsageStatQuery) Order(o ...sorausagestat.OrderOption) *SoraUsageStatQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first SoraUsageStat entity from the query.
// Returns a *NotFoundError when no SoraUsageStat was found.
func (_q *SoraUsageStatQuery) First(ctx context.Context) (*SoraUsageStat, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{sorausagestat.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *SoraUsageStatQuery) FirstX(ctx context.Context) *SoraUsageStat {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first SoraUsageStat ID from the query.
// Returns a *NotFoundError when no SoraUsageStat ID was found.
func (_q *SoraUsageStatQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{sorausagestat.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *SoraUsageStatQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single SoraUsageStat entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one SoraUsageStat entity is found.
// Returns a *NotFoundError when no SoraUsageStat entities are found.
func (_q *SoraUsageStatQuery) Only(ctx context.Context) (*SoraUsageStat, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{sorausagestat.Label}
default:
return nil, &NotSingularError{sorausagestat.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *SoraUsageStatQuery) OnlyX(ctx context.Context) *SoraUsageStat {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only SoraUsageStat ID in the query.
// Returns a *NotSingularError when more than one SoraUsageStat ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *SoraUsageStatQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{sorausagestat.Label}
default:
err = &NotSingularError{sorausagestat.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *SoraUsageStatQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of SoraUsageStats.
func (_q *SoraUsageStatQuery) All(ctx context.Context) ([]*SoraUsageStat, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*SoraUsageStat, *SoraUsageStatQuery]()
return withInterceptors[[]*SoraUsageStat](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *SoraUsageStatQuery) AllX(ctx context.Context) []*SoraUsageStat {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of SoraUsageStat IDs.
func (_q *SoraUsageStatQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(sorausagestat.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *SoraUsageStatQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *SoraUsageStatQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*SoraUsageStatQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *SoraUsageStatQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *SoraUsageStatQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *SoraUsageStatQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the SoraUsageStatQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *SoraUsageStatQuery) Clone() *SoraUsageStatQuery {
if _q == nil {
return nil
}
return &SoraUsageStatQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]sorausagestat.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.SoraUsageStat{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.SoraUsageStat.Query().
// GroupBy(sorausagestat.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *SoraUsageStatQuery) GroupBy(field string, fields ...string) *SoraUsageStatGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &SoraUsageStatGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = sorausagestat.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.SoraUsageStat.Query().
// Select(sorausagestat.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *SoraUsageStatQuery) Select(fields ...string) *SoraUsageStatSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &SoraUsageStatSelect{SoraUsageStatQuery: _q}
sbuild.label = sorausagestat.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a SoraUsageStatSelect configured with the given aggregations.
func (_q *SoraUsageStatQuery) Aggregate(fns ...AggregateFunc) *SoraUsageStatSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *SoraUsageStatQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !sorausagestat.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *SoraUsageStatQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SoraUsageStat, error) {
var (
nodes = []*SoraUsageStat{}
_spec = _q.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*SoraUsageStat).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &SoraUsageStat{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (_q *SoraUsageStatQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *SoraUsageStatQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(sorausagestat.Table, sorausagestat.Columns, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, sorausagestat.FieldID)
for i := range fields {
if fields[i] != sorausagestat.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *SoraUsageStatQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(sorausagestat.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = sorausagestat.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *SoraUsageStatQuery) ForUpdate(opts ...sql.LockOption) *SoraUsageStatQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *SoraUsageStatQuery) ForShare(opts ...sql.LockOption) *SoraUsageStatQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// SoraUsageStatGroupBy is the group-by builder for SoraUsageStat entities.
type SoraUsageStatGroupBy struct {
selector
build *SoraUsageStatQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *SoraUsageStatGroupBy) Aggregate(fns ...AggregateFunc) *SoraUsageStatGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *SoraUsageStatGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SoraUsageStatQuery, *SoraUsageStatGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *SoraUsageStatGroupBy) sqlScan(ctx context.Context, root *SoraUsageStatQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// SoraUsageStatSelect is the builder for selecting fields of SoraUsageStat entities.
type SoraUsageStatSelect struct {
*SoraUsageStatQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *SoraUsageStatSelect) Aggregate(fns ...AggregateFunc) *SoraUsageStatSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *SoraUsageStatSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SoraUsageStatQuery, *SoraUsageStatSelect](ctx, _s.SoraUsageStatQuery, _s, _s.inters, v)
}
func (_s *SoraUsageStatSelect) sqlScan(ctx context.Context, root *SoraUsageStatQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/sorausagestat"
)
// SoraUsageStatUpdate is the builder for updating SoraUsageStat entities.
type SoraUsageStatUpdate struct {
config
hooks []Hook
mutation *SoraUsageStatMutation
}
// Where appends a list predicates to the SoraUsageStatUpdate builder.
func (_u *SoraUsageStatUpdate) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SoraUsageStatUpdate) SetUpdatedAt(v time.Time) *SoraUsageStatUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetAccountID sets the "account_id" field.
func (_u *SoraUsageStatUpdate) SetAccountID(v int64) *SoraUsageStatUpdate {
_u.mutation.ResetAccountID()
_u.mutation.SetAccountID(v)
return _u
}
// SetNillableAccountID sets the "account_id" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableAccountID(v *int64) *SoraUsageStatUpdate {
if v != nil {
_u.SetAccountID(*v)
}
return _u
}
// AddAccountID adds value to the "account_id" field.
func (_u *SoraUsageStatUpdate) AddAccountID(v int64) *SoraUsageStatUpdate {
_u.mutation.AddAccountID(v)
return _u
}
// SetImageCount sets the "image_count" field.
func (_u *SoraUsageStatUpdate) SetImageCount(v int) *SoraUsageStatUpdate {
_u.mutation.ResetImageCount()
_u.mutation.SetImageCount(v)
return _u
}
// SetNillableImageCount sets the "image_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableImageCount(v *int) *SoraUsageStatUpdate {
if v != nil {
_u.SetImageCount(*v)
}
return _u
}
// AddImageCount adds value to the "image_count" field.
func (_u *SoraUsageStatUpdate) AddImageCount(v int) *SoraUsageStatUpdate {
_u.mutation.AddImageCount(v)
return _u
}
// SetVideoCount sets the "video_count" field.
func (_u *SoraUsageStatUpdate) SetVideoCount(v int) *SoraUsageStatUpdate {
_u.mutation.ResetVideoCount()
_u.mutation.SetVideoCount(v)
return _u
}
// SetNillableVideoCount sets the "video_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableVideoCount(v *int) *SoraUsageStatUpdate {
if v != nil {
_u.SetVideoCount(*v)
}
return _u
}
// AddVideoCount adds value to the "video_count" field.
func (_u *SoraUsageStatUpdate) AddVideoCount(v int) *SoraUsageStatUpdate {
_u.mutation.AddVideoCount(v)
return _u
}
// SetErrorCount sets the "error_count" field.
func (_u *SoraUsageStatUpdate) SetErrorCount(v int) *SoraUsageStatUpdate {
_u.mutation.ResetErrorCount()
_u.mutation.SetErrorCount(v)
return _u
}
// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableErrorCount(v *int) *SoraUsageStatUpdate {
if v != nil {
_u.SetErrorCount(*v)
}
return _u
}
// AddErrorCount adds value to the "error_count" field.
func (_u *SoraUsageStatUpdate) AddErrorCount(v int) *SoraUsageStatUpdate {
_u.mutation.AddErrorCount(v)
return _u
}
// SetLastErrorAt sets the "last_error_at" field.
func (_u *SoraUsageStatUpdate) SetLastErrorAt(v time.Time) *SoraUsageStatUpdate {
_u.mutation.SetLastErrorAt(v)
return _u
}
// SetNillableLastErrorAt sets the "last_error_at" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableLastErrorAt(v *time.Time) *SoraUsageStatUpdate {
if v != nil {
_u.SetLastErrorAt(*v)
}
return _u
}
// ClearLastErrorAt clears the value of the "last_error_at" field.
func (_u *SoraUsageStatUpdate) ClearLastErrorAt() *SoraUsageStatUpdate {
_u.mutation.ClearLastErrorAt()
return _u
}
// SetTodayImageCount sets the "today_image_count" field.
func (_u *SoraUsageStatUpdate) SetTodayImageCount(v int) *SoraUsageStatUpdate {
_u.mutation.ResetTodayImageCount()
_u.mutation.SetTodayImageCount(v)
return _u
}
// SetNillableTodayImageCount sets the "today_image_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableTodayImageCount(v *int) *SoraUsageStatUpdate {
if v != nil {
_u.SetTodayImageCount(*v)
}
return _u
}
// AddTodayImageCount adds value to the "today_image_count" field.
func (_u *SoraUsageStatUpdate) AddTodayImageCount(v int) *SoraUsageStatUpdate {
_u.mutation.AddTodayImageCount(v)
return _u
}
// SetTodayVideoCount sets the "today_video_count" field.
func (_u *SoraUsageStatUpdate) SetTodayVideoCount(v int) *SoraUsageStatUpdate {
_u.mutation.ResetTodayVideoCount()
_u.mutation.SetTodayVideoCount(v)
return _u
}
// SetNillableTodayVideoCount sets the "today_video_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableTodayVideoCount(v *int) *SoraUsageStatUpdate {
if v != nil {
_u.SetTodayVideoCount(*v)
}
return _u
}
// AddTodayVideoCount adds value to the "today_video_count" field.
func (_u *SoraUsageStatUpdate) AddTodayVideoCount(v int) *SoraUsageStatUpdate {
_u.mutation.AddTodayVideoCount(v)
return _u
}
// SetTodayErrorCount sets the "today_error_count" field.
func (_u *SoraUsageStatUpdate) SetTodayErrorCount(v int) *SoraUsageStatUpdate {
_u.mutation.ResetTodayErrorCount()
_u.mutation.SetTodayErrorCount(v)
return _u
}
// SetNillableTodayErrorCount sets the "today_error_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableTodayErrorCount(v *int) *SoraUsageStatUpdate {
if v != nil {
_u.SetTodayErrorCount(*v)
}
return _u
}
// AddTodayErrorCount adds value to the "today_error_count" field.
func (_u *SoraUsageStatUpdate) AddTodayErrorCount(v int) *SoraUsageStatUpdate {
_u.mutation.AddTodayErrorCount(v)
return _u
}
// SetTodayDate sets the "today_date" field.
func (_u *SoraUsageStatUpdate) SetTodayDate(v time.Time) *SoraUsageStatUpdate {
_u.mutation.SetTodayDate(v)
return _u
}
// SetNillableTodayDate sets the "today_date" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableTodayDate(v *time.Time) *SoraUsageStatUpdate {
if v != nil {
_u.SetTodayDate(*v)
}
return _u
}
// ClearTodayDate clears the value of the "today_date" field.
func (_u *SoraUsageStatUpdate) ClearTodayDate() *SoraUsageStatUpdate {
_u.mutation.ClearTodayDate()
return _u
}
// SetConsecutiveErrorCount sets the "consecutive_error_count" field.
func (_u *SoraUsageStatUpdate) SetConsecutiveErrorCount(v int) *SoraUsageStatUpdate {
_u.mutation.ResetConsecutiveErrorCount()
_u.mutation.SetConsecutiveErrorCount(v)
return _u
}
// SetNillableConsecutiveErrorCount sets the "consecutive_error_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdate) SetNillableConsecutiveErrorCount(v *int) *SoraUsageStatUpdate {
if v != nil {
_u.SetConsecutiveErrorCount(*v)
}
return _u
}
// AddConsecutiveErrorCount adds value to the "consecutive_error_count" field.
func (_u *SoraUsageStatUpdate) AddConsecutiveErrorCount(v int) *SoraUsageStatUpdate {
_u.mutation.AddConsecutiveErrorCount(v)
return _u
}
// Mutation returns the SoraUsageStatMutation object of the builder.
func (_u *SoraUsageStatUpdate) Mutation() *SoraUsageStatMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *SoraUsageStatUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SoraUsageStatUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *SoraUsageStatUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SoraUsageStatUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *SoraUsageStatUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := sorausagestat.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
func (_u *SoraUsageStatUpdate) sqlSave(ctx context.Context) (_node int, err error) {
_spec := sqlgraph.NewUpdateSpec(sorausagestat.Table, sorausagestat.Columns, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(sorausagestat.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.AccountID(); ok {
_spec.SetField(sorausagestat.FieldAccountID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedAccountID(); ok {
_spec.AddField(sorausagestat.FieldAccountID, field.TypeInt64, value)
}
if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(sorausagestat.FieldImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedImageCount(); ok {
_spec.AddField(sorausagestat.FieldImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.VideoCount(); ok {
_spec.SetField(sorausagestat.FieldVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedVideoCount(); ok {
_spec.AddField(sorausagestat.FieldVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.ErrorCount(); ok {
_spec.SetField(sorausagestat.FieldErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedErrorCount(); ok {
_spec.AddField(sorausagestat.FieldErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.LastErrorAt(); ok {
_spec.SetField(sorausagestat.FieldLastErrorAt, field.TypeTime, value)
}
if _u.mutation.LastErrorAtCleared() {
_spec.ClearField(sorausagestat.FieldLastErrorAt, field.TypeTime)
}
if value, ok := _u.mutation.TodayImageCount(); ok {
_spec.SetField(sorausagestat.FieldTodayImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedTodayImageCount(); ok {
_spec.AddField(sorausagestat.FieldTodayImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.TodayVideoCount(); ok {
_spec.SetField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedTodayVideoCount(); ok {
_spec.AddField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.TodayErrorCount(); ok {
_spec.SetField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedTodayErrorCount(); ok {
_spec.AddField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.TodayDate(); ok {
_spec.SetField(sorausagestat.FieldTodayDate, field.TypeTime, value)
}
if _u.mutation.TodayDateCleared() {
_spec.ClearField(sorausagestat.FieldTodayDate, field.TypeTime)
}
if value, ok := _u.mutation.ConsecutiveErrorCount(); ok {
_spec.SetField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedConsecutiveErrorCount(); ok {
_spec.AddField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{sorausagestat.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// SoraUsageStatUpdateOne is the builder for updating a single SoraUsageStat entity.
type SoraUsageStatUpdateOne struct {
config
fields []string
hooks []Hook
mutation *SoraUsageStatMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SoraUsageStatUpdateOne) SetUpdatedAt(v time.Time) *SoraUsageStatUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetAccountID sets the "account_id" field.
func (_u *SoraUsageStatUpdateOne) SetAccountID(v int64) *SoraUsageStatUpdateOne {
_u.mutation.ResetAccountID()
_u.mutation.SetAccountID(v)
return _u
}
// SetNillableAccountID sets the "account_id" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableAccountID(v *int64) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetAccountID(*v)
}
return _u
}
// AddAccountID adds value to the "account_id" field.
func (_u *SoraUsageStatUpdateOne) AddAccountID(v int64) *SoraUsageStatUpdateOne {
_u.mutation.AddAccountID(v)
return _u
}
// SetImageCount sets the "image_count" field.
func (_u *SoraUsageStatUpdateOne) SetImageCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.ResetImageCount()
_u.mutation.SetImageCount(v)
return _u
}
// SetNillableImageCount sets the "image_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableImageCount(v *int) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetImageCount(*v)
}
return _u
}
// AddImageCount adds value to the "image_count" field.
func (_u *SoraUsageStatUpdateOne) AddImageCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.AddImageCount(v)
return _u
}
// SetVideoCount sets the "video_count" field.
func (_u *SoraUsageStatUpdateOne) SetVideoCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.ResetVideoCount()
_u.mutation.SetVideoCount(v)
return _u
}
// SetNillableVideoCount sets the "video_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableVideoCount(v *int) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetVideoCount(*v)
}
return _u
}
// AddVideoCount adds value to the "video_count" field.
func (_u *SoraUsageStatUpdateOne) AddVideoCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.AddVideoCount(v)
return _u
}
// SetErrorCount sets the "error_count" field.
func (_u *SoraUsageStatUpdateOne) SetErrorCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.ResetErrorCount()
_u.mutation.SetErrorCount(v)
return _u
}
// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableErrorCount(v *int) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetErrorCount(*v)
}
return _u
}
// AddErrorCount adds value to the "error_count" field.
func (_u *SoraUsageStatUpdateOne) AddErrorCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.AddErrorCount(v)
return _u
}
// SetLastErrorAt sets the "last_error_at" field.
func (_u *SoraUsageStatUpdateOne) SetLastErrorAt(v time.Time) *SoraUsageStatUpdateOne {
_u.mutation.SetLastErrorAt(v)
return _u
}
// SetNillableLastErrorAt sets the "last_error_at" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableLastErrorAt(v *time.Time) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetLastErrorAt(*v)
}
return _u
}
// ClearLastErrorAt clears the value of the "last_error_at" field.
func (_u *SoraUsageStatUpdateOne) ClearLastErrorAt() *SoraUsageStatUpdateOne {
_u.mutation.ClearLastErrorAt()
return _u
}
// SetTodayImageCount sets the "today_image_count" field.
func (_u *SoraUsageStatUpdateOne) SetTodayImageCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.ResetTodayImageCount()
_u.mutation.SetTodayImageCount(v)
return _u
}
// SetNillableTodayImageCount sets the "today_image_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableTodayImageCount(v *int) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetTodayImageCount(*v)
}
return _u
}
// AddTodayImageCount adds value to the "today_image_count" field.
func (_u *SoraUsageStatUpdateOne) AddTodayImageCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.AddTodayImageCount(v)
return _u
}
// SetTodayVideoCount sets the "today_video_count" field.
func (_u *SoraUsageStatUpdateOne) SetTodayVideoCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.ResetTodayVideoCount()
_u.mutation.SetTodayVideoCount(v)
return _u
}
// SetNillableTodayVideoCount sets the "today_video_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableTodayVideoCount(v *int) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetTodayVideoCount(*v)
}
return _u
}
// AddTodayVideoCount adds value to the "today_video_count" field.
func (_u *SoraUsageStatUpdateOne) AddTodayVideoCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.AddTodayVideoCount(v)
return _u
}
// SetTodayErrorCount sets the "today_error_count" field.
func (_u *SoraUsageStatUpdateOne) SetTodayErrorCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.ResetTodayErrorCount()
_u.mutation.SetTodayErrorCount(v)
return _u
}
// SetNillableTodayErrorCount sets the "today_error_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableTodayErrorCount(v *int) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetTodayErrorCount(*v)
}
return _u
}
// AddTodayErrorCount adds value to the "today_error_count" field.
func (_u *SoraUsageStatUpdateOne) AddTodayErrorCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.AddTodayErrorCount(v)
return _u
}
// SetTodayDate sets the "today_date" field.
func (_u *SoraUsageStatUpdateOne) SetTodayDate(v time.Time) *SoraUsageStatUpdateOne {
_u.mutation.SetTodayDate(v)
return _u
}
// SetNillableTodayDate sets the "today_date" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableTodayDate(v *time.Time) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetTodayDate(*v)
}
return _u
}
// ClearTodayDate clears the value of the "today_date" field.
func (_u *SoraUsageStatUpdateOne) ClearTodayDate() *SoraUsageStatUpdateOne {
_u.mutation.ClearTodayDate()
return _u
}
// SetConsecutiveErrorCount sets the "consecutive_error_count" field.
func (_u *SoraUsageStatUpdateOne) SetConsecutiveErrorCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.ResetConsecutiveErrorCount()
_u.mutation.SetConsecutiveErrorCount(v)
return _u
}
// SetNillableConsecutiveErrorCount sets the "consecutive_error_count" field if the given value is not nil.
func (_u *SoraUsageStatUpdateOne) SetNillableConsecutiveErrorCount(v *int) *SoraUsageStatUpdateOne {
if v != nil {
_u.SetConsecutiveErrorCount(*v)
}
return _u
}
// AddConsecutiveErrorCount adds value to the "consecutive_error_count" field.
func (_u *SoraUsageStatUpdateOne) AddConsecutiveErrorCount(v int) *SoraUsageStatUpdateOne {
_u.mutation.AddConsecutiveErrorCount(v)
return _u
}
// Mutation returns the SoraUsageStatMutation object of the builder.
func (_u *SoraUsageStatUpdateOne) Mutation() *SoraUsageStatMutation {
return _u.mutation
}
// Where appends a list predicates to the SoraUsageStatUpdate builder.
func (_u *SoraUsageStatUpdateOne) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *SoraUsageStatUpdateOne) Select(field string, fields ...string) *SoraUsageStatUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated SoraUsageStat entity.
func (_u *SoraUsageStatUpdateOne) Save(ctx context.Context) (*SoraUsageStat, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SoraUsageStatUpdateOne) SaveX(ctx context.Context) *SoraUsageStat {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *SoraUsageStatUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SoraUsageStatUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *SoraUsageStatUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := sorausagestat.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
func (_u *SoraUsageStatUpdateOne) sqlSave(ctx context.Context) (_node *SoraUsageStat, err error) {
_spec := sqlgraph.NewUpdateSpec(sorausagestat.Table, sorausagestat.Columns, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SoraUsageStat.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, sorausagestat.FieldID)
for _, f := range fields {
if !sorausagestat.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != sorausagestat.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(sorausagestat.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.AccountID(); ok {
_spec.SetField(sorausagestat.FieldAccountID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedAccountID(); ok {
_spec.AddField(sorausagestat.FieldAccountID, field.TypeInt64, value)
}
if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(sorausagestat.FieldImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedImageCount(); ok {
_spec.AddField(sorausagestat.FieldImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.VideoCount(); ok {
_spec.SetField(sorausagestat.FieldVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedVideoCount(); ok {
_spec.AddField(sorausagestat.FieldVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.ErrorCount(); ok {
_spec.SetField(sorausagestat.FieldErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedErrorCount(); ok {
_spec.AddField(sorausagestat.FieldErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.LastErrorAt(); ok {
_spec.SetField(sorausagestat.FieldLastErrorAt, field.TypeTime, value)
}
if _u.mutation.LastErrorAtCleared() {
_spec.ClearField(sorausagestat.FieldLastErrorAt, field.TypeTime)
}
if value, ok := _u.mutation.TodayImageCount(); ok {
_spec.SetField(sorausagestat.FieldTodayImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedTodayImageCount(); ok {
_spec.AddField(sorausagestat.FieldTodayImageCount, field.TypeInt, value)
}
if value, ok := _u.mutation.TodayVideoCount(); ok {
_spec.SetField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedTodayVideoCount(); ok {
_spec.AddField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value)
}
if value, ok := _u.mutation.TodayErrorCount(); ok {
_spec.SetField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedTodayErrorCount(); ok {
_spec.AddField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.TodayDate(); ok {
_spec.SetField(sorausagestat.FieldTodayDate, field.TypeTime, value)
}
if _u.mutation.TodayDateCleared() {
_spec.ClearField(sorausagestat.FieldTodayDate, field.TypeTime)
}
if value, ok := _u.mutation.ConsecutiveErrorCount(); ok {
_spec.SetField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedConsecutiveErrorCount(); ok {
_spec.AddField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value)
}
_node = &SoraUsageStat{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{sorausagestat.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}
...@@ -32,6 +32,14 @@ type Tx struct { ...@@ -32,6 +32,14 @@ type Tx struct {
RedeemCode *RedeemCodeClient RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders. // Setting is the client for interacting with the Setting builders.
Setting *SettingClient Setting *SettingClient
// SoraAccount is the client for interacting with the SoraAccount builders.
SoraAccount *SoraAccountClient
// SoraCacheFile is the client for interacting with the SoraCacheFile builders.
SoraCacheFile *SoraCacheFileClient
// SoraTask is the client for interacting with the SoraTask builders.
SoraTask *SoraTaskClient
// SoraUsageStat is the client for interacting with the SoraUsageStat builders.
SoraUsageStat *SoraUsageStatClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
UsageCleanupTask *UsageCleanupTaskClient UsageCleanupTask *UsageCleanupTaskClient
// UsageLog is the client for interacting with the UsageLog builders. // UsageLog is the client for interacting with the UsageLog builders.
...@@ -186,6 +194,10 @@ func (tx *Tx) init() { ...@@ -186,6 +194,10 @@ func (tx *Tx) init() {
tx.Proxy = NewProxyClient(tx.config) tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.Setting = NewSettingClient(tx.config) tx.Setting = NewSettingClient(tx.config)
tx.SoraAccount = NewSoraAccountClient(tx.config)
tx.SoraCacheFile = NewSoraCacheFileClient(tx.config)
tx.SoraTask = NewSoraTaskClient(tx.config)
tx.SoraUsageStat = NewSoraUsageStatClient(tx.config)
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config)
tx.User = NewUserClient(tx.config) tx.User = NewUserClient(tx.config)
......
...@@ -58,6 +58,7 @@ type Config struct { ...@@ -58,6 +58,7 @@ type Config struct {
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Sora SoraConfig `mapstructure:"sora"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"` Gemini GeminiConfig `mapstructure:"gemini"`
...@@ -69,6 +70,38 @@ type GeminiConfig struct { ...@@ -69,6 +70,38 @@ type GeminiConfig struct {
Quota GeminiQuotaConfig `mapstructure:"quota"` Quota GeminiQuotaConfig `mapstructure:"quota"`
} }
type SoraConfig struct {
BaseURL string `mapstructure:"base_url"`
Timeout int `mapstructure:"timeout"`
MaxRetries int `mapstructure:"max_retries"`
PollInterval float64 `mapstructure:"poll_interval"`
CallLogicMode string `mapstructure:"call_logic_mode"`
Cache SoraCacheConfig `mapstructure:"cache"`
WatermarkFree SoraWatermarkFreeConfig `mapstructure:"watermark_free"`
TokenRefresh SoraTokenRefreshConfig `mapstructure:"token_refresh"`
}
type SoraCacheConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseDir string `mapstructure:"base_dir"`
VideoDir string `mapstructure:"video_dir"`
MaxBytes int64 `mapstructure:"max_bytes"`
AllowedHosts []string `mapstructure:"allowed_hosts"`
UserDirEnabled bool `mapstructure:"user_dir_enabled"`
}
type SoraWatermarkFreeConfig struct {
Enabled bool `mapstructure:"enabled"`
ParseMethod string `mapstructure:"parse_method"`
CustomParseURL string `mapstructure:"custom_parse_url"`
CustomParseToken string `mapstructure:"custom_parse_token"`
FallbackOnFailure bool `mapstructure:"fallback_on_failure"`
}
type SoraTokenRefreshConfig struct {
Enabled bool `mapstructure:"enabled"`
}
type GeminiOAuthConfig struct { type GeminiOAuthConfig struct {
ClientID string `mapstructure:"client_id"` ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"` ClientSecret string `mapstructure:"client_secret"`
...@@ -862,6 +895,24 @@ func setDefaults() { ...@@ -862,6 +895,24 @@ func setDefaults() {
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("sora.base_url", "https://sora.chatgpt.com/backend")
viper.SetDefault("sora.timeout", 120)
viper.SetDefault("sora.max_retries", 3)
viper.SetDefault("sora.poll_interval", 2.5)
viper.SetDefault("sora.call_logic_mode", "default")
viper.SetDefault("sora.cache.enabled", false)
viper.SetDefault("sora.cache.base_dir", "tmp/sora")
viper.SetDefault("sora.cache.video_dir", "data/video")
viper.SetDefault("sora.cache.max_bytes", int64(0))
viper.SetDefault("sora.cache.allowed_hosts", []string{})
viper.SetDefault("sora.cache.user_dir_enabled", true)
viper.SetDefault("sora.watermark_free.enabled", false)
viper.SetDefault("sora.watermark_free.parse_method", "third_party")
viper.SetDefault("sora.watermark_free.custom_parse_url", "")
viper.SetDefault("sora.watermark_free.custom_parse_token", "")
viper.SetDefault("sora.watermark_free.fallback_on_failure", true)
viper.SetDefault("sora.token_refresh.enabled", false)
// Gemini OAuth - configure via environment variables or config file // Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
// Default: uses Gemini CLI public credentials (set via environment) // Default: uses Gemini CLI public credentials (set via environment)
......
...@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { ...@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct { type CreateGroupRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"` IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
...@@ -49,7 +49,7 @@ type CreateGroupRequest struct { ...@@ -49,7 +49,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct { type UpdateGroupRequest struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"` RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"` IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
......
...@@ -79,6 +79,23 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -79,6 +79,23 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
FallbackModelAntigravity: settings.FallbackModelAntigravity, FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch, EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt, IdentityPatchPrompt: settings.IdentityPatchPrompt,
SoraBaseURL: settings.SoraBaseURL,
SoraTimeout: settings.SoraTimeout,
SoraMaxRetries: settings.SoraMaxRetries,
SoraPollInterval: settings.SoraPollInterval,
SoraCallLogicMode: settings.SoraCallLogicMode,
SoraCacheEnabled: settings.SoraCacheEnabled,
SoraCacheBaseDir: settings.SoraCacheBaseDir,
SoraCacheVideoDir: settings.SoraCacheVideoDir,
SoraCacheMaxBytes: settings.SoraCacheMaxBytes,
SoraCacheAllowedHosts: settings.SoraCacheAllowedHosts,
SoraCacheUserDirEnabled: settings.SoraCacheUserDirEnabled,
SoraWatermarkFreeEnabled: settings.SoraWatermarkFreeEnabled,
SoraWatermarkFreeParseMethod: settings.SoraWatermarkFreeParseMethod,
SoraWatermarkFreeCustomParseURL: settings.SoraWatermarkFreeCustomParseURL,
SoraWatermarkFreeCustomParseToken: settings.SoraWatermarkFreeCustomParseToken,
SoraWatermarkFreeFallbackOnFailure: settings.SoraWatermarkFreeFallbackOnFailure,
SoraTokenRefreshEnabled: settings.SoraTokenRefreshEnabled,
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled, OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsQueryModeDefault: settings.OpsQueryModeDefault,
...@@ -138,6 +155,25 @@ type UpdateSettingsRequest struct { ...@@ -138,6 +155,25 @@ type UpdateSettingsRequest struct {
EnableIdentityPatch bool `json:"enable_identity_patch"` EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"` IdentityPatchPrompt string `json:"identity_patch_prompt"`
// Sora configuration
SoraBaseURL string `json:"sora_base_url"`
SoraTimeout int `json:"sora_timeout"`
SoraMaxRetries int `json:"sora_max_retries"`
SoraPollInterval float64 `json:"sora_poll_interval"`
SoraCallLogicMode string `json:"sora_call_logic_mode"`
SoraCacheEnabled bool `json:"sora_cache_enabled"`
SoraCacheBaseDir string `json:"sora_cache_base_dir"`
SoraCacheVideoDir string `json:"sora_cache_video_dir"`
SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"`
SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"`
SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"`
SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"`
SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"`
SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"`
SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"`
SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"`
SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"`
// Ops monitoring (vNext) // Ops monitoring (vNext)
OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"` OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"`
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"` OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
...@@ -227,6 +263,32 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -227,6 +263,32 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
} }
// Sora 参数校验与清理
req.SoraBaseURL = strings.TrimSpace(req.SoraBaseURL)
if req.SoraBaseURL == "" {
req.SoraBaseURL = previousSettings.SoraBaseURL
}
if req.SoraBaseURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.SoraBaseURL); err != nil {
response.BadRequest(c, "Sora Base URL must be an absolute http(s) URL")
return
}
}
if req.SoraTimeout <= 0 {
req.SoraTimeout = previousSettings.SoraTimeout
}
if req.SoraMaxRetries < 0 {
req.SoraMaxRetries = previousSettings.SoraMaxRetries
}
if req.SoraPollInterval <= 0 {
req.SoraPollInterval = previousSettings.SoraPollInterval
}
if req.SoraCacheMaxBytes < 0 {
req.SoraCacheMaxBytes = 0
}
req.SoraCacheAllowedHosts = normalizeStringList(req.SoraCacheAllowedHosts)
req.SoraWatermarkFreeCustomParseURL = strings.TrimSpace(req.SoraWatermarkFreeCustomParseURL)
// Ops metrics collector interval validation (seconds). // Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil { if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds v := *req.OpsMetricsIntervalSeconds
...@@ -240,40 +302,57 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -240,40 +302,57 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
settings := &service.SystemSettings{ settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled, RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled, PromoCodeEnabled: req.PromoCodeEnabled,
SMTPHost: req.SMTPHost, SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort, SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername, SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword, SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom, SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName, SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS, SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled, TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey, TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey, TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName, SiteName: req.SiteName,
SiteLogo: req.SiteLogo, SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle, SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL, APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo, ContactInfo: req.ContactInfo,
DocURL: req.DocURL, DocURL: req.DocURL,
HomeContent: req.HomeContent, HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton, HideCcsImportButton: req.HideCcsImportButton,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback, EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic, FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI, FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini, FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity, FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch, EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt, IdentityPatchPrompt: req.IdentityPatchPrompt,
SoraBaseURL: req.SoraBaseURL,
SoraTimeout: req.SoraTimeout,
SoraMaxRetries: req.SoraMaxRetries,
SoraPollInterval: req.SoraPollInterval,
SoraCallLogicMode: req.SoraCallLogicMode,
SoraCacheEnabled: req.SoraCacheEnabled,
SoraCacheBaseDir: req.SoraCacheBaseDir,
SoraCacheVideoDir: req.SoraCacheVideoDir,
SoraCacheMaxBytes: req.SoraCacheMaxBytes,
SoraCacheAllowedHosts: req.SoraCacheAllowedHosts,
SoraCacheUserDirEnabled: req.SoraCacheUserDirEnabled,
SoraWatermarkFreeEnabled: req.SoraWatermarkFreeEnabled,
SoraWatermarkFreeParseMethod: req.SoraWatermarkFreeParseMethod,
SoraWatermarkFreeCustomParseURL: req.SoraWatermarkFreeCustomParseURL,
SoraWatermarkFreeCustomParseToken: req.SoraWatermarkFreeCustomParseToken,
SoraWatermarkFreeFallbackOnFailure: req.SoraWatermarkFreeFallbackOnFailure,
SoraTokenRefreshEnabled: req.SoraTokenRefreshEnabled,
OpsMonitoringEnabled: func() bool { OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil { if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled return *req.OpsMonitoringEnabled
...@@ -349,6 +428,23 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -349,6 +428,23 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch, EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
SoraBaseURL: updatedSettings.SoraBaseURL,
SoraTimeout: updatedSettings.SoraTimeout,
SoraMaxRetries: updatedSettings.SoraMaxRetries,
SoraPollInterval: updatedSettings.SoraPollInterval,
SoraCallLogicMode: updatedSettings.SoraCallLogicMode,
SoraCacheEnabled: updatedSettings.SoraCacheEnabled,
SoraCacheBaseDir: updatedSettings.SoraCacheBaseDir,
SoraCacheVideoDir: updatedSettings.SoraCacheVideoDir,
SoraCacheMaxBytes: updatedSettings.SoraCacheMaxBytes,
SoraCacheAllowedHosts: updatedSettings.SoraCacheAllowedHosts,
SoraCacheUserDirEnabled: updatedSettings.SoraCacheUserDirEnabled,
SoraWatermarkFreeEnabled: updatedSettings.SoraWatermarkFreeEnabled,
SoraWatermarkFreeParseMethod: updatedSettings.SoraWatermarkFreeParseMethod,
SoraWatermarkFreeCustomParseURL: updatedSettings.SoraWatermarkFreeCustomParseURL,
SoraWatermarkFreeCustomParseToken: updatedSettings.SoraWatermarkFreeCustomParseToken,
SoraWatermarkFreeFallbackOnFailure: updatedSettings.SoraWatermarkFreeFallbackOnFailure,
SoraTokenRefreshEnabled: updatedSettings.SoraTokenRefreshEnabled,
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled, OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
...@@ -477,6 +573,57 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -477,6 +573,57 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.IdentityPatchPrompt != after.IdentityPatchPrompt { if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
changed = append(changed, "identity_patch_prompt") changed = append(changed, "identity_patch_prompt")
} }
if before.SoraBaseURL != after.SoraBaseURL {
changed = append(changed, "sora_base_url")
}
if before.SoraTimeout != after.SoraTimeout {
changed = append(changed, "sora_timeout")
}
if before.SoraMaxRetries != after.SoraMaxRetries {
changed = append(changed, "sora_max_retries")
}
if before.SoraPollInterval != after.SoraPollInterval {
changed = append(changed, "sora_poll_interval")
}
if before.SoraCallLogicMode != after.SoraCallLogicMode {
changed = append(changed, "sora_call_logic_mode")
}
if before.SoraCacheEnabled != after.SoraCacheEnabled {
changed = append(changed, "sora_cache_enabled")
}
if before.SoraCacheBaseDir != after.SoraCacheBaseDir {
changed = append(changed, "sora_cache_base_dir")
}
if before.SoraCacheVideoDir != after.SoraCacheVideoDir {
changed = append(changed, "sora_cache_video_dir")
}
if before.SoraCacheMaxBytes != after.SoraCacheMaxBytes {
changed = append(changed, "sora_cache_max_bytes")
}
if strings.Join(before.SoraCacheAllowedHosts, ",") != strings.Join(after.SoraCacheAllowedHosts, ",") {
changed = append(changed, "sora_cache_allowed_hosts")
}
if before.SoraCacheUserDirEnabled != after.SoraCacheUserDirEnabled {
changed = append(changed, "sora_cache_user_dir_enabled")
}
if before.SoraWatermarkFreeEnabled != after.SoraWatermarkFreeEnabled {
changed = append(changed, "sora_watermark_free_enabled")
}
if before.SoraWatermarkFreeParseMethod != after.SoraWatermarkFreeParseMethod {
changed = append(changed, "sora_watermark_free_parse_method")
}
if before.SoraWatermarkFreeCustomParseURL != after.SoraWatermarkFreeCustomParseURL {
changed = append(changed, "sora_watermark_free_custom_parse_url")
}
if before.SoraWatermarkFreeCustomParseToken != after.SoraWatermarkFreeCustomParseToken {
changed = append(changed, "sora_watermark_free_custom_parse_token")
}
if before.SoraWatermarkFreeFallbackOnFailure != after.SoraWatermarkFreeFallbackOnFailure {
changed = append(changed, "sora_watermark_free_fallback_on_failure")
}
if before.SoraTokenRefreshEnabled != after.SoraTokenRefreshEnabled {
changed = append(changed, "sora_token_refresh_enabled")
}
if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled { if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled {
changed = append(changed, "ops_monitoring_enabled") changed = append(changed, "ops_monitoring_enabled")
} }
...@@ -492,6 +639,19 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -492,6 +639,19 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
return changed return changed
} }
func normalizeStringList(values []string) []string {
if len(values) == 0 {
return []string{}
}
normalized := make([]string, 0, len(values))
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
normalized = append(normalized, trimmed)
}
}
return normalized
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host" binding:"required"`
......
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SoraAccountHandler Sora 账号扩展管理
// 提供 Sora 扩展表的查询与更新能力。
type SoraAccountHandler struct {
adminService service.AdminService
soraAccountRepo service.SoraAccountRepository
usageRepo service.SoraUsageStatRepository
}
// NewSoraAccountHandler 创建 SoraAccountHandler
func NewSoraAccountHandler(adminService service.AdminService, soraAccountRepo service.SoraAccountRepository, usageRepo service.SoraUsageStatRepository) *SoraAccountHandler {
return &SoraAccountHandler{
adminService: adminService,
soraAccountRepo: soraAccountRepo,
usageRepo: usageRepo,
}
}
// SoraAccountUpdateRequest 更新/创建 Sora 账号扩展请求
// 使用指针类型区分未提供与设置为空值。
type SoraAccountUpdateRequest struct {
AccessToken *string `json:"access_token"`
SessionToken *string `json:"session_token"`
RefreshToken *string `json:"refresh_token"`
ClientID *string `json:"client_id"`
Email *string `json:"email"`
Username *string `json:"username"`
Remark *string `json:"remark"`
UseCount *int `json:"use_count"`
PlanType *string `json:"plan_type"`
PlanTitle *string `json:"plan_title"`
SubscriptionEnd *int64 `json:"subscription_end"`
SoraSupported *bool `json:"sora_supported"`
SoraInviteCode *string `json:"sora_invite_code"`
SoraRedeemedCount *int `json:"sora_redeemed_count"`
SoraRemainingCount *int `json:"sora_remaining_count"`
SoraTotalCount *int `json:"sora_total_count"`
SoraCooldownUntil *int64 `json:"sora_cooldown_until"`
CooledUntil *int64 `json:"cooled_until"`
ImageEnabled *bool `json:"image_enabled"`
VideoEnabled *bool `json:"video_enabled"`
ImageConcurrency *int `json:"image_concurrency"`
VideoConcurrency *int `json:"video_concurrency"`
IsExpired *bool `json:"is_expired"`
}
// SoraAccountBatchRequest 批量导入请求
// accounts 支持批量 upsert。
type SoraAccountBatchRequest struct {
Accounts []SoraAccountBatchItem `json:"accounts"`
}
// SoraAccountBatchItem 批量导入条目
type SoraAccountBatchItem struct {
AccountID int64 `json:"account_id"`
SoraAccountUpdateRequest
}
// SoraAccountBatchResult 批量导入结果
// 仅返回成功/失败数量与明细。
type SoraAccountBatchResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []SoraAccountBatchItemResult `json:"results"`
}
// SoraAccountBatchItemResult 批量导入单条结果
type SoraAccountBatchItemResult struct {
AccountID int64 `json:"account_id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// List 获取 Sora 账号扩展列表
// GET /api/v1/admin/sora/accounts
func (h *SoraAccountHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
search := strings.TrimSpace(c.Query("search"))
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, service.PlatformSora, "", "", search)
if err != nil {
response.ErrorFrom(c, err)
return
}
accountIDs := make([]int64, 0, len(accounts))
for i := range accounts {
accountIDs = append(accountIDs, accounts[i].ID)
}
soraMap := map[int64]*service.SoraAccount{}
if h.soraAccountRepo != nil {
soraMap, _ = h.soraAccountRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
}
usageMap := map[int64]*service.SoraUsageStat{}
if h.usageRepo != nil {
usageMap, _ = h.usageRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
}
result := make([]dto.SoraAccount, 0, len(accounts))
for i := range accounts {
acc := accounts[i]
item := dto.SoraAccountFromService(&acc, soraMap[acc.ID], usageMap[acc.ID])
if item != nil {
result = append(result, *item)
}
}
response.Paginated(c, result, total, page, pageSize)
}
// Get 获取单个 Sora 账号扩展
// GET /api/v1/admin/sora/accounts/:id
func (h *SoraAccountHandler) Get(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "账号 ID 无效")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if account.Platform != service.PlatformSora {
response.BadRequest(c, "账号不是 Sora 平台")
return
}
var soraAcc *service.SoraAccount
if h.soraAccountRepo != nil {
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
}
var usage *service.SoraUsageStat
if h.usageRepo != nil {
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
}
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
}
// Upsert 更新或创建 Sora 账号扩展
// PUT /api/v1/admin/sora/accounts/:id
func (h *SoraAccountHandler) Upsert(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "账号 ID 无效")
return
}
var req SoraAccountUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数无效: "+err.Error())
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if account.Platform != service.PlatformSora {
response.BadRequest(c, "账号不是 Sora 平台")
return
}
updates := buildSoraAccountUpdates(&req)
if h.soraAccountRepo != nil && len(updates) > 0 {
if err := h.soraAccountRepo.Upsert(c.Request.Context(), accountID, updates); err != nil {
response.ErrorFrom(c, err)
return
}
}
var soraAcc *service.SoraAccount
if h.soraAccountRepo != nil {
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
}
var usage *service.SoraUsageStat
if h.usageRepo != nil {
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
}
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
}
// BatchUpsert 批量导入 Sora 账号扩展
// POST /api/v1/admin/sora/accounts/import
func (h *SoraAccountHandler) BatchUpsert(c *gin.Context) {
var req SoraAccountBatchRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数无效: "+err.Error())
return
}
if len(req.Accounts) == 0 {
response.BadRequest(c, "accounts 不能为空")
return
}
ids := make([]int64, 0, len(req.Accounts))
for _, item := range req.Accounts {
if item.AccountID > 0 {
ids = append(ids, item.AccountID)
}
}
accountMap := make(map[int64]*service.Account, len(ids))
if len(ids) > 0 {
accounts, _ := h.adminService.GetAccountsByIDs(c.Request.Context(), ids)
for _, acc := range accounts {
if acc != nil {
accountMap[acc.ID] = acc
}
}
}
result := SoraAccountBatchResult{
Results: make([]SoraAccountBatchItemResult, 0, len(req.Accounts)),
}
for _, item := range req.Accounts {
entry := SoraAccountBatchItemResult{AccountID: item.AccountID}
acc := accountMap[item.AccountID]
if acc == nil {
entry.Error = "账号不存在"
result.Results = append(result.Results, entry)
result.Failed++
continue
}
if acc.Platform != service.PlatformSora {
entry.Error = "账号不是 Sora 平台"
result.Results = append(result.Results, entry)
result.Failed++
continue
}
updates := buildSoraAccountUpdates(&item.SoraAccountUpdateRequest)
if h.soraAccountRepo != nil && len(updates) > 0 {
if err := h.soraAccountRepo.Upsert(c.Request.Context(), item.AccountID, updates); err != nil {
entry.Error = err.Error()
result.Results = append(result.Results, entry)
result.Failed++
continue
}
}
entry.Success = true
result.Results = append(result.Results, entry)
result.Success++
}
response.Success(c, result)
}
// ListUsage 获取 Sora 调用统计
// GET /api/v1/admin/sora/usage
func (h *SoraAccountHandler) ListUsage(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
if h.usageRepo == nil {
response.Paginated(c, []dto.SoraUsageStat{}, 0, page, pageSize)
return
}
stats, paginationResult, err := h.usageRepo.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
result := make([]dto.SoraUsageStat, 0, len(stats))
for _, stat := range stats {
item := dto.SoraUsageStatFromService(stat)
if item != nil {
result = append(result, *item)
}
}
response.Paginated(c, result, paginationResult.Total, paginationResult.Page, paginationResult.PageSize)
}
func buildSoraAccountUpdates(req *SoraAccountUpdateRequest) map[string]any {
if req == nil {
return nil
}
updates := make(map[string]any)
setString := func(key string, value *string) {
if value == nil {
return
}
updates[key] = strings.TrimSpace(*value)
}
setString("access_token", req.AccessToken)
setString("session_token", req.SessionToken)
setString("refresh_token", req.RefreshToken)
setString("client_id", req.ClientID)
setString("email", req.Email)
setString("username", req.Username)
setString("remark", req.Remark)
setString("plan_type", req.PlanType)
setString("plan_title", req.PlanTitle)
setString("sora_invite_code", req.SoraInviteCode)
if req.UseCount != nil {
updates["use_count"] = *req.UseCount
}
if req.SoraSupported != nil {
updates["sora_supported"] = *req.SoraSupported
}
if req.SoraRedeemedCount != nil {
updates["sora_redeemed_count"] = *req.SoraRedeemedCount
}
if req.SoraRemainingCount != nil {
updates["sora_remaining_count"] = *req.SoraRemainingCount
}
if req.SoraTotalCount != nil {
updates["sora_total_count"] = *req.SoraTotalCount
}
if req.ImageEnabled != nil {
updates["image_enabled"] = *req.ImageEnabled
}
if req.VideoEnabled != nil {
updates["video_enabled"] = *req.VideoEnabled
}
if req.ImageConcurrency != nil {
updates["image_concurrency"] = *req.ImageConcurrency
}
if req.VideoConcurrency != nil {
updates["video_concurrency"] = *req.VideoConcurrency
}
if req.IsExpired != nil {
updates["is_expired"] = *req.IsExpired
}
if req.SubscriptionEnd != nil && *req.SubscriptionEnd > 0 {
updates["subscription_end"] = time.Unix(*req.SubscriptionEnd, 0).UTC()
}
if req.SoraCooldownUntil != nil && *req.SoraCooldownUntil > 0 {
updates["sora_cooldown_until"] = time.Unix(*req.SoraCooldownUntil, 0).UTC()
}
if req.CooledUntil != nil && *req.CooledUntil > 0 {
updates["cooled_until"] = time.Unix(*req.CooledUntil, 0).UTC()
}
return updates
}
...@@ -287,6 +287,72 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi ...@@ -287,6 +287,72 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
} }
} }
func SoraUsageStatFromService(stat *service.SoraUsageStat) *SoraUsageStat {
if stat == nil {
return nil
}
return &SoraUsageStat{
AccountID: stat.AccountID,
ImageCount: stat.ImageCount,
VideoCount: stat.VideoCount,
ErrorCount: stat.ErrorCount,
LastErrorAt: timeToUnixSeconds(stat.LastErrorAt),
TodayImageCount: stat.TodayImageCount,
TodayVideoCount: stat.TodayVideoCount,
TodayErrorCount: stat.TodayErrorCount,
TodayDate: timeToUnixSeconds(stat.TodayDate),
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
CreatedAt: stat.CreatedAt,
UpdatedAt: stat.UpdatedAt,
}
}
func SoraAccountFromService(account *service.Account, soraAcc *service.SoraAccount, usage *service.SoraUsageStat) *SoraAccount {
if account == nil {
return nil
}
out := &SoraAccount{
AccountID: account.ID,
AccountName: account.Name,
AccountStatus: account.Status,
AccountType: account.Type,
AccountConcurrency: account.Concurrency,
ProxyID: account.ProxyID,
Usage: SoraUsageStatFromService(usage),
CreatedAt: account.CreatedAt,
UpdatedAt: account.UpdatedAt,
}
if soraAcc == nil {
return out
}
out.AccessToken = soraAcc.AccessToken
out.SessionToken = soraAcc.SessionToken
out.RefreshToken = soraAcc.RefreshToken
out.ClientID = soraAcc.ClientID
out.Email = soraAcc.Email
out.Username = soraAcc.Username
out.Remark = soraAcc.Remark
out.UseCount = soraAcc.UseCount
out.PlanType = soraAcc.PlanType
out.PlanTitle = soraAcc.PlanTitle
out.SubscriptionEnd = timeToUnixSeconds(soraAcc.SubscriptionEnd)
out.SoraSupported = soraAcc.SoraSupported
out.SoraInviteCode = soraAcc.SoraInviteCode
out.SoraRedeemedCount = soraAcc.SoraRedeemedCount
out.SoraRemainingCount = soraAcc.SoraRemainingCount
out.SoraTotalCount = soraAcc.SoraTotalCount
out.SoraCooldownUntil = timeToUnixSeconds(soraAcc.SoraCooldownUntil)
out.CooledUntil = timeToUnixSeconds(soraAcc.CooledUntil)
out.ImageEnabled = soraAcc.ImageEnabled
out.VideoEnabled = soraAcc.VideoEnabled
out.ImageConcurrency = soraAcc.ImageConcurrency
out.VideoConcurrency = soraAcc.VideoConcurrency
out.IsExpired = soraAcc.IsExpired
out.CreatedAt = soraAcc.CreatedAt
out.UpdatedAt = soraAcc.UpdatedAt
return out
}
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary { func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
if a == nil { if a == nil {
return nil return nil
......
...@@ -46,6 +46,25 @@ type SystemSettings struct { ...@@ -46,6 +46,25 @@ type SystemSettings struct {
EnableIdentityPatch bool `json:"enable_identity_patch"` EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"` IdentityPatchPrompt string `json:"identity_patch_prompt"`
// Sora configuration
SoraBaseURL string `json:"sora_base_url"`
SoraTimeout int `json:"sora_timeout"`
SoraMaxRetries int `json:"sora_max_retries"`
SoraPollInterval float64 `json:"sora_poll_interval"`
SoraCallLogicMode string `json:"sora_call_logic_mode"`
SoraCacheEnabled bool `json:"sora_cache_enabled"`
SoraCacheBaseDir string `json:"sora_cache_base_dir"`
SoraCacheVideoDir string `json:"sora_cache_video_dir"`
SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"`
SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"`
SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"`
SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"`
SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"`
SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"`
SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"`
SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"`
SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"`
// Ops monitoring (vNext) // Ops monitoring (vNext)
OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"` OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"`
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"` OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
......
...@@ -141,6 +141,56 @@ type Account struct { ...@@ -141,6 +141,56 @@ type Account struct {
Groups []*Group `json:"groups,omitempty"` Groups []*Group `json:"groups,omitempty"`
} }
type SoraUsageStat struct {
AccountID int64 `json:"account_id"`
ImageCount int `json:"image_count"`
VideoCount int `json:"video_count"`
ErrorCount int `json:"error_count"`
LastErrorAt *int64 `json:"last_error_at"`
TodayImageCount int `json:"today_image_count"`
TodayVideoCount int `json:"today_video_count"`
TodayErrorCount int `json:"today_error_count"`
TodayDate *int64 `json:"today_date"`
ConsecutiveErrorCount int `json:"consecutive_error_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type SoraAccount struct {
AccountID int64 `json:"account_id"`
AccountName string `json:"account_name"`
AccountStatus string `json:"account_status"`
AccountType string `json:"account_type"`
AccountConcurrency int `json:"account_concurrency"`
ProxyID *int64 `json:"proxy_id"`
AccessToken string `json:"access_token"`
SessionToken string `json:"session_token"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
Email string `json:"email"`
Username string `json:"username"`
Remark string `json:"remark"`
UseCount int `json:"use_count"`
PlanType string `json:"plan_type"`
PlanTitle string `json:"plan_title"`
SubscriptionEnd *int64 `json:"subscription_end"`
SoraSupported bool `json:"sora_supported"`
SoraInviteCode string `json:"sora_invite_code"`
SoraRedeemedCount int `json:"sora_redeemed_count"`
SoraRemainingCount int `json:"sora_remaining_count"`
SoraTotalCount int `json:"sora_total_count"`
SoraCooldownUntil *int64 `json:"sora_cooldown_until"`
CooledUntil *int64 `json:"cooled_until"`
ImageEnabled bool `json:"image_enabled"`
VideoEnabled bool `json:"video_enabled"`
ImageConcurrency int `json:"image_concurrency"`
VideoConcurrency int `json:"video_concurrency"`
IsExpired bool `json:"is_expired"`
Usage *SoraUsageStat `json:"usage,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type AccountGroup struct { type AccountGroup struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -508,6 +509,13 @@ func (h *GatewayHandler) Models(c *gin.Context) { ...@@ -508,6 +509,13 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}) })
return return
} }
if platform == service.PlatformSora {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": sora.ListModels(),
})
return
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"object": "list", "object": "list",
......
...@@ -17,6 +17,7 @@ type AdminHandlers struct { ...@@ -17,6 +17,7 @@ type AdminHandlers struct {
Proxy *admin.ProxyHandler Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler Redeem *admin.RedeemHandler
Promo *admin.PromoHandler Promo *admin.PromoHandler
SoraAccount *admin.SoraAccountHandler
Setting *admin.SettingHandler Setting *admin.SettingHandler
Ops *admin.OpsHandler Ops *admin.OpsHandler
System *admin.SystemHandler System *admin.SystemHandler
...@@ -36,6 +37,7 @@ type Handlers struct { ...@@ -36,6 +37,7 @@ type Handlers struct {
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler Setting *SettingHandler
} }
......
...@@ -814,6 +814,8 @@ func guessPlatformFromPath(path string) string { ...@@ -814,6 +814,8 @@ func guessPlatformFromPath(path string) string {
return service.PlatformAntigravity return service.PlatformAntigravity
case strings.HasPrefix(p, "/v1beta/"): case strings.HasPrefix(p, "/v1beta/"):
return service.PlatformGemini return service.PlatformGemini
case strings.Contains(p, "/chat/completions"):
return service.PlatformSora
case strings.Contains(p, "/responses"): case strings.Contains(p, "/responses"):
return service.PlatformOpenAI return service.PlatformOpenAI
default: default:
......
package handler
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SoraGatewayHandler handles Sora OpenAI compatible endpoints.
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler.
func NewSoraGatewayHandler(
gatewayService *service.GatewayService,
soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *SoraGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
}
}
// ChatCompletions handles Sora OpenAI-compatible chat completions endpoint.
// POST /v1/chat/completions
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
model, _ := reqBody["model"].(string)
if strings.TrimSpace(model) == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
stream, _ := reqBody["stream"].(bool)
prompt, imageData, videoData, remixID, err := parseSoraPrompt(reqBody)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
return
}
if remixID == "" {
remixID = sora.ExtractRemixID(prompt)
}
if remixID != "" {
prompt = strings.ReplaceAll(prompt, remixID, "")
}
if apiKey.Group != nil && apiKey.Group.Platform != service.PlatformSora {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "当前分组不支持 Sora 平台")
return
}
streamStarted := false
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err == nil && canWait {
waitCounted = true
}
if err == nil && !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, stream, &streamStarted)
if err != nil {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
failedAccountIDs := make(map[int64]struct{})
maxSwitches := h.maxAccountSwitches
if mode := h.soraGatewayService.CallLogicMode(c.Request.Context()); strings.EqualFold(mode, "native") {
maxSwitches = 1
}
for switchCount := 0; switchCount < maxSwitches; switchCount++ {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, "", model, failedAccountIDs, "")
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "server_error", err.Error())
return
}
account := selection.Account
releaseFunc := selection.ReleaseFunc
result, err := h.soraGatewayService.Generate(c.Request.Context(), account, service.SoraGenerationRequest{
Model: model,
Prompt: prompt,
Image: imageData,
Video: videoData,
RemixTargetID: remixID,
Stream: stream,
UserID: subject.UserID,
})
if err != nil {
// 失败路径:立即释放槽位,而非 defer
if releaseFunc != nil {
releaseFunc()
}
if errors.Is(err, service.ErrSoraAccountMissingToken) || errors.Is(err, service.ErrSoraAccountNotEligible) {
failedAccountIDs[account.ID] = struct{}{}
continue
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "server_error", err.Error(), streamStarted)
return
}
// 成功路径:使用 defer 在函数退出时释放
if releaseFunc != nil {
defer releaseFunc()
}
h.respondCompletion(c, model, result, stream)
return
}
h.handleFailoverExhausted(c, http.StatusServiceUnavailable, streamStarted)
}
func (h *SoraGatewayHandler) respondCompletion(c *gin.Context, model string, result *service.SoraGenerationResult, stream bool) {
if result == nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Empty response")
return
}
if stream {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
first := buildSoraStreamChunk(model, "", true, "")
if _, err := c.Writer.WriteString(first); err != nil {
return
}
final := buildSoraStreamChunk(model, result.Content, false, "stop")
if _, err := c.Writer.WriteString(final); err != nil {
return
}
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
return
}
c.JSON(http.StatusOK, buildSoraNonStreamResponse(model, result.Content))
}
func buildSoraStreamChunk(model, content string, isFirst bool, finishReason string) string {
chunkID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli())
delta := map[string]any{}
if isFirst {
delta["role"] = "assistant"
}
if content != "" {
delta["content"] = content
} else {
delta["content"] = nil
}
response := map[string]any{
"id": chunkID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"delta": delta,
"finish_reason": finishReason,
},
},
}
payload, _ := json.Marshal(response)
return "data: " + string(payload) + "\n\n"
}
func buildSoraNonStreamResponse(model, content string) map[string]any {
return map[string]any{
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()),
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": content,
},
"finish_reason": "stop",
},
},
}
}
func parseSoraPrompt(req map[string]any) (prompt, imageData, videoData, remixID string, err error) {
messages, ok := req["messages"].([]any)
if !ok || len(messages) == 0 {
return "", "", "", "", fmt.Errorf("messages is required")
}
last := messages[len(messages)-1]
msg, ok := last.(map[string]any)
if !ok {
return "", "", "", "", fmt.Errorf("invalid message format")
}
content, ok := msg["content"]
if !ok {
return "", "", "", "", fmt.Errorf("content is required")
}
if v, ok := req["image"].(string); ok && v != "" {
imageData = v
}
if v, ok := req["video"].(string); ok && v != "" {
videoData = v
}
if v, ok := req["remix_target_id"].(string); ok {
remixID = v
}
switch value := content.(type) {
case string:
prompt = value
case []any:
for _, item := range value {
part, ok := item.(map[string]any)
if !ok {
continue
}
switch part["type"] {
case "text":
if text, ok := part["text"].(string); ok {
prompt = text
}
case "image_url":
if image, ok := part["image_url"].(map[string]any); ok {
if url, ok := image["url"].(string); ok {
imageData = url
}
}
case "video_url":
if video, ok := part["video_url"].(map[string]any); ok {
if url, ok := video["url"].(string); ok {
videoData = url
}
}
}
}
default:
return "", "", "", "", fmt.Errorf("invalid content format")
}
if strings.TrimSpace(prompt) == "" && strings.TrimSpace(videoData) == "" {
return "", "", "", "", fmt.Errorf("prompt is required")
}
return prompt, imageData, videoData, remixID, nil
}
func looksLikeURL(value string) bool {
trimmed := strings.ToLower(strings.TrimSpace(value))
return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://")
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
if streamStarted {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", err.Error(), true)
return
}
c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()})
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
message := "No available Sora accounts"
h.handleStreamingAwareError(c, statusCode, "server_error", message, streamStarted)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
payload := map[string]any{"error": map[string]any{"message": message, "type": errType, "param": nil, "code": nil}}
data, _ := json.Marshal(payload)
_, _ = c.Writer.WriteString("data: " + string(data) + "\n\n")
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
return
}
h.errorResponse(c, status, errType, message)
}
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"message": message,
"type": errType,
"param": nil,
"code": nil,
},
})
}
...@@ -20,6 +20,7 @@ func ProvideAdminHandlers( ...@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
proxyHandler *admin.ProxyHandler, proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler, redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler, promoHandler *admin.PromoHandler,
soraAccountHandler *admin.SoraAccountHandler,
settingHandler *admin.SettingHandler, settingHandler *admin.SettingHandler,
opsHandler *admin.OpsHandler, opsHandler *admin.OpsHandler,
systemHandler *admin.SystemHandler, systemHandler *admin.SystemHandler,
...@@ -39,6 +40,7 @@ func ProvideAdminHandlers( ...@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
Proxy: proxyHandler, Proxy: proxyHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Promo: promoHandler, Promo: promoHandler,
SoraAccount: soraAccountHandler,
Setting: settingHandler, Setting: settingHandler,
Ops: opsHandler, Ops: opsHandler,
System: systemHandler, System: systemHandler,
...@@ -69,6 +71,7 @@ func ProvideHandlers( ...@@ -69,6 +71,7 @@ func ProvideHandlers(
adminHandlers *AdminHandlers, adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
...@@ -81,6 +84,7 @@ func ProvideHandlers( ...@@ -81,6 +84,7 @@ func ProvideHandlers(
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
} }
} }
...@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet( ...@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler, NewSubscriptionHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
ProvideSettingHandler, ProvideSettingHandler,
// Admin handlers // Admin handlers
...@@ -110,6 +115,7 @@ var ProviderSet = wire.NewSet( ...@@ -110,6 +115,7 @@ var ProviderSet = wire.NewSet(
admin.NewProxyHandler, admin.NewProxyHandler,
admin.NewRedeemHandler, admin.NewRedeemHandler,
admin.NewPromoHandler, admin.NewPromoHandler,
admin.NewSoraAccountHandler,
admin.NewSettingHandler, admin.NewSettingHandler,
admin.NewOpsHandler, admin.NewOpsHandler,
ProvideSystemHandler, ProvideSystemHandler,
......
package sora
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
)
// UploadCharacterVideo uploads a character video and returns cameo ID.
func (c *Client) UploadCharacterVideo(ctx context.Context, opts RequestOptions, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("video data empty")
}
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
if err := writeMultipartFile(writer, "file", "video.mp4", "video/mp4", data); err != nil {
return "", err
}
if err := writer.WriteField("timestamps", "0,3"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/characters/upload", opts, &buf, writer.FormDataContentType(), false)
if err != nil {
return "", err
}
return stringFromJSON(resp, "id"), nil
}
// GetCameoStatus returns cameo processing status.
func (c *Client) GetCameoStatus(ctx context.Context, opts RequestOptions, cameoID string) (map[string]any, error) {
if cameoID == "" {
return nil, errors.New("cameo id empty")
}
return c.doRequest(ctx, "GET", "/project_y/cameos/in_progress/"+cameoID, opts, nil, "", false)
}
// DownloadCharacterImage downloads character avatar image data.
func (c *Client) DownloadCharacterImage(ctx context.Context, opts RequestOptions, imageURL string) ([]byte, error) {
if c.upstream == nil {
return nil, errors.New("upstream is nil")
}
req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", defaultDesktopUA)
resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("download image failed: %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// UploadCharacterImage uploads character avatar and returns asset pointer.
func (c *Client) UploadCharacterImage(ctx context.Context, opts RequestOptions, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("image data empty")
}
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
if err := writeMultipartFile(writer, "file", "profile.webp", "image/webp", data); err != nil {
return "", err
}
if err := writer.WriteField("use_case", "profile"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/project_y/file/upload", opts, &buf, writer.FormDataContentType(), false)
if err != nil {
return "", err
}
return stringFromJSON(resp, "asset_pointer"), nil
}
// FinalizeCharacter finalizes character creation and returns character ID.
func (c *Client) FinalizeCharacter(ctx context.Context, opts RequestOptions, cameoID, username, displayName, assetPointer string) (string, error) {
payload := map[string]any{
"cameo_id": cameoID,
"username": username,
"display_name": displayName,
"profile_asset_pointer": assetPointer,
"instruction_set": nil,
"safety_instruction_set": nil,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/characters/finalize", opts, bytes.NewReader(body), "application/json", false)
if err != nil {
return "", err
}
if character, ok := resp["character"].(map[string]any); ok {
if id, ok := character["character_id"].(string); ok {
return id, nil
}
}
return "", nil
}
// SetCharacterPublic marks character as public.
func (c *Client) SetCharacterPublic(ctx context.Context, opts RequestOptions, cameoID string) error {
payload := map[string]any{"visibility": "public"}
body, err := json.Marshal(payload)
if err != nil {
return err
}
_, err = c.doRequest(ctx, "POST", "/project_y/cameos/by_id/"+cameoID+"/update_v2", opts, bytes.NewReader(body), "application/json", false)
return err
}
// DeleteCharacter deletes a character by ID.
func (c *Client) DeleteCharacter(ctx context.Context, opts RequestOptions, characterID string) error {
if characterID == "" {
return nil
}
_, err := c.doRequest(ctx, "DELETE", "/project_y/characters/"+characterID, opts, nil, "", false)
return err
}
func writeMultipartFile(writer *multipart.Writer, field, filename, contentType string, data []byte) error {
header := make(textproto.MIMEHeader)
header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, field, filename))
if contentType != "" {
header.Set("Content-Type", contentType)
}
part, err := writer.CreatePart(header)
if err != nil {
return err
}
_, err = part.Write(data)
return err
}
package sora
import (
"bytes"
"context"
"crypto/sha3"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
const (
chatGPTBaseURL = "https://chatgpt.com"
sentinelFlow = "sora_2_create_task"
maxAPIResponseSize = 1 * 1024 * 1024 // 1MB
)
var (
defaultMobileUA = "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)"
defaultDesktopUA = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
sentinelCache sync.Map // 包级缓存,存储 Sentinel Token,key 为 accountID
)
// sentinelCacheEntry 是 Sentinel Token 缓存条目
type sentinelCacheEntry struct {
token string
expiresAt time.Time
}
// UpstreamClient defines the HTTP client interface for Sora requests.
type UpstreamClient interface {
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
}
// Client is a minimal Sora API client.
type Client struct {
baseURL string
timeout time.Duration
upstream UpstreamClient
enableTLSFingerprint bool
}
// RequestOptions configures per-request context.
type RequestOptions struct {
AccountID int64
AccountConcurrency int
ProxyURL string
AccessToken string
}
// getCachedSentinel 从缓存中获取 Sentinel Token
func getCachedSentinel(accountID int64) (string, bool) {
v, ok := sentinelCache.Load(accountID)
if !ok {
return "", false
}
entry := v.(*sentinelCacheEntry)
if time.Now().After(entry.expiresAt) {
sentinelCache.Delete(accountID)
return "", false
}
return entry.token, true
}
// cacheSentinel 缓存 Sentinel Token
func cacheSentinel(accountID int64, token string) {
sentinelCache.Store(accountID, &sentinelCacheEntry{
token: token,
expiresAt: time.Now().Add(3 * time.Minute), // 3分钟有效期
})
}
// NewClient creates a Sora client.
func NewClient(baseURL string, timeout time.Duration, upstream UpstreamClient, enableTLSFingerprint bool) *Client {
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
timeout: timeout,
upstream: upstream,
enableTLSFingerprint: enableTLSFingerprint,
}
}
// UploadImage uploads an image and returns media ID.
func (c *Client) UploadImage(ctx context.Context, opts RequestOptions, data []byte, filename string) (string, error) {
if filename == "" {
filename = "image.png"
}
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("file_name", filename); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/uploads", opts, &buf, writer.FormDataContentType(), false)
if err != nil {
return "", err
}
return stringFromJSON(resp, "id"), nil
}
// GenerateImage creates an image generation task.
func (c *Client) GenerateImage(ctx context.Context, opts RequestOptions, prompt string, width, height int, mediaID string) (string, error) {
operation := "simple_compose"
var inpaint []map[string]any
if mediaID != "" {
operation = "remix"
inpaint = []map[string]any{
{
"type": "image",
"frame_index": 0,
"upload_media_id": mediaID,
},
}
}
payload := map[string]any{
"type": "image_gen",
"operation": operation,
"prompt": prompt,
"width": width,
"height": height,
"n_variants": 1,
"n_frames": 1,
"inpaint_items": inpaint,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/video_gen", opts, bytes.NewReader(body), "application/json", true)
if err != nil {
return "", err
}
return stringFromJSON(resp, "id"), nil
}
// GenerateVideo creates a video generation task.
func (c *Client) GenerateVideo(ctx context.Context, opts RequestOptions, prompt, orientation string, nFrames int, mediaID, styleID, model, size string) (string, error) {
var inpaint []map[string]any
if mediaID != "" {
inpaint = []map[string]any{{"kind": "upload", "upload_id": mediaID}}
}
payload := map[string]any{
"kind": "video",
"prompt": prompt,
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"model": model,
"inpaint_items": inpaint,
"style_id": styleID,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/nf/create", opts, bytes.NewReader(body), "application/json", true)
if err != nil {
return "", err
}
return stringFromJSON(resp, "id"), nil
}
// GenerateStoryboard creates a storyboard video task.
func (c *Client) GenerateStoryboard(ctx context.Context, opts RequestOptions, prompt, orientation string, nFrames int, mediaID, styleID string) (string, error) {
var inpaint []map[string]any
if mediaID != "" {
inpaint = []map[string]any{{"kind": "upload", "upload_id": mediaID}}
}
payload := map[string]any{
"kind": "video",
"prompt": prompt,
"title": "Draft your video",
"orientation": orientation,
"size": "small",
"n_frames": nFrames,
"storyboard_id": nil,
"inpaint_items": inpaint,
"remix_target_id": nil,
"model": "sy_8",
"metadata": nil,
"style_id": styleID,
"cameo_ids": nil,
"cameo_replacements": nil,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/nf/create/storyboard", opts, bytes.NewReader(body), "application/json", true)
if err != nil {
return "", err
}
return stringFromJSON(resp, "id"), nil
}
// RemixVideo creates a remix task.
func (c *Client) RemixVideo(ctx context.Context, opts RequestOptions, remixTargetID, prompt, orientation string, nFrames int, styleID string) (string, error) {
payload := map[string]any{
"kind": "video",
"prompt": prompt,
"inpaint_items": []map[string]any{},
"remix_target_id": remixTargetID,
"cameo_ids": []string{},
"cameo_replacements": map[string]any{},
"model": "sy_8",
"orientation": orientation,
"n_frames": nFrames,
"style_id": styleID,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/nf/create", opts, bytes.NewReader(body), "application/json", true)
if err != nil {
return "", err
}
return stringFromJSON(resp, "id"), nil
}
// GetImageTasks returns recent image tasks.
func (c *Client) GetImageTasks(ctx context.Context, opts RequestOptions) (map[string]any, error) {
return c.doRequest(ctx, "GET", "/v2/recent_tasks?limit=20", opts, nil, "", false)
}
// GetPendingTasks returns pending video tasks.
func (c *Client) GetPendingTasks(ctx context.Context, opts RequestOptions) ([]map[string]any, error) {
resp, err := c.doRequestAny(ctx, "GET", "/nf/pending/v2", opts, nil, "", false)
if err != nil {
return nil, err
}
switch v := resp.(type) {
case []any:
return convertList(v), nil
case map[string]any:
if list, ok := v["items"].([]any); ok {
return convertList(list), nil
}
if arr, ok := v["data"].([]any); ok {
return convertList(arr), nil
}
return convertListFromAny(v), nil
default:
return nil, nil
}
}
// GetVideoDrafts returns recent video drafts.
func (c *Client) GetVideoDrafts(ctx context.Context, opts RequestOptions) (map[string]any, error) {
return c.doRequest(ctx, "GET", "/project_y/profile/drafts?limit=15", opts, nil, "", false)
}
// EnhancePrompt calls prompt enhancement API.
func (c *Client) EnhancePrompt(ctx context.Context, opts RequestOptions, prompt, expansionLevel string, durationS int) (string, error) {
payload := map[string]any{
"prompt": prompt,
"expansion_level": expansionLevel,
"duration_s": durationS,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/editor/enhance_prompt", opts, bytes.NewReader(body), "application/json", false)
if err != nil {
return "", err
}
return stringFromJSON(resp, "enhanced_prompt"), nil
}
// PostVideoForWatermarkFree publishes a video for watermark-free parsing.
func (c *Client) PostVideoForWatermarkFree(ctx context.Context, opts RequestOptions, generationID string) (string, error) {
payload := map[string]any{
"attachments_to_create": []map[string]any{{
"generation_id": generationID,
"kind": "sora",
}},
"post_text": "",
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
resp, err := c.doRequest(ctx, "POST", "/project_y/post", opts, bytes.NewReader(body), "application/json", true)
if err != nil {
return "", err
}
post, _ := resp["post"].(map[string]any)
if post == nil {
return "", nil
}
return stringFromJSON(post, "id"), nil
}
// DeletePost deletes a Sora post.
func (c *Client) DeletePost(ctx context.Context, opts RequestOptions, postID string) error {
if postID == "" {
return nil
}
_, err := c.doRequest(ctx, "DELETE", "/project_y/post/"+postID, opts, nil, "", false)
return err
}
func (c *Client) doRequest(ctx context.Context, method, endpoint string, opts RequestOptions, body io.Reader, contentType string, addSentinel bool) (map[string]any, error) {
resp, err := c.doRequestAny(ctx, method, endpoint, opts, body, contentType, addSentinel)
if err != nil {
return nil, err
}
parsed, ok := resp.(map[string]any)
if !ok {
return nil, errors.New("unexpected response format")
}
return parsed, nil
}
func (c *Client) doRequestAny(ctx context.Context, method, endpoint string, opts RequestOptions, body io.Reader, contentType string, addSentinel bool) (any, error) {
if c.upstream == nil {
return nil, errors.New("upstream is nil")
}
url := c.baseURL + endpoint
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
if contentType != "" {
req.Header.Set("Content-Type", contentType)
}
if opts.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
}
req.Header.Set("User-Agent", defaultMobileUA)
if addSentinel {
sentinel, err := c.generateSentinelToken(ctx, opts)
if err != nil {
return nil, err
}
req.Header.Set("openai-sentinel-token", sentinel)
}
resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 使用 LimitReader 限制最大响应大小,防止 DoS 攻击
limitedReader := io.LimitReader(resp.Body, maxAPIResponseSize+1)
data, err := io.ReadAll(limitedReader)
if err != nil {
return nil, err
}
// 检查是否超过大小限制
if int64(len(data)) > maxAPIResponseSize {
return nil, fmt.Errorf("API 响应过大 (最大 %d 字节)", maxAPIResponseSize)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("sora api error: %d %s", resp.StatusCode, strings.TrimSpace(string(data)))
}
if len(data) == 0 {
return map[string]any{}, nil
}
var parsed any
if err := json.Unmarshal(data, &parsed); err != nil {
return nil, err
}
return parsed, nil
}
func (c *Client) generateSentinelToken(ctx context.Context, opts RequestOptions) (string, error) {
// 尝试从缓存获取
if token, ok := getCachedSentinel(opts.AccountID); ok {
return token, nil
}
reqID := uuid.New().String()
powToken, err := generatePowToken(defaultDesktopUA)
if err != nil {
return "", err
}
payload := map[string]any{"p": powToken, "flow": sentinelFlow, "id": reqID}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
url := chatGPTBaseURL + "/backend-api/sentinel/req"
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", defaultDesktopUA)
if opts.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
}
resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint)
if err != nil {
return "", err
}
defer resp.Body.Close()
// 使用 LimitReader 限制最大响应大小,防止 DoS 攻击
limitedReader := io.LimitReader(resp.Body, maxAPIResponseSize+1)
data, err := io.ReadAll(limitedReader)
if err != nil {
return "", err
}
// 检查是否超过大小限制
if int64(len(data)) > maxAPIResponseSize {
return "", fmt.Errorf("API 响应过大 (最大 %d 字节)", maxAPIResponseSize)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("sentinel request failed: %d %s", resp.StatusCode, strings.TrimSpace(string(data)))
}
var parsed map[string]any
if err := json.Unmarshal(data, &parsed); err != nil {
return "", err
}
token := buildSentinelToken(reqID, powToken, parsed)
// 缓存结果
cacheSentinel(opts.AccountID, token)
return token, nil
}
func buildSentinelToken(reqID, powToken string, resp map[string]any) string {
finalPow := powToken
pow, _ := resp["proofofwork"].(map[string]any)
if pow != nil {
required, _ := pow["required"].(bool)
if required {
seed, _ := pow["seed"].(string)
difficulty, _ := pow["difficulty"].(string)
if seed != "" && difficulty != "" {
candidate, _ := solvePow(seed, difficulty, defaultDesktopUA)
if candidate != "" {
finalPow = "gAAAAAB" + candidate
}
}
}
}
if !strings.HasSuffix(finalPow, "~S") {
finalPow += "~S"
}
turnstile := ""
if t, ok := resp["turnstile"].(map[string]any); ok {
turnstile, _ = t["dx"].(string)
}
token := ""
if v, ok := resp["token"].(string); ok {
token = v
}
payload := map[string]any{
"p": finalPow,
"t": turnstile,
"c": token,
"id": reqID,
"flow": sentinelFlow,
}
data, _ := json.Marshal(payload)
return string(data)
}
func generatePowToken(userAgent string) (string, error) {
seed := fmt.Sprintf("%f", float64(time.Now().UnixNano())/1e9)
candidate, _ := solvePow(seed, "0fffff", userAgent)
if candidate == "" {
return "", errors.New("pow generation failed")
}
return "gAAAAAC" + candidate, nil
}
func solvePow(seed, difficulty, userAgent string) (string, bool) {
config := powConfig(userAgent)
seedBytes := []byte(seed)
diffBytes, err := hexDecode(difficulty)
if err != nil {
return "", false
}
configBytes, err := json.Marshal(config)
if err != nil {
return "", false
}
prefix := configBytes[:len(configBytes)-1]
for i := 0; i < 500000; i++ {
payload := append(prefix, []byte(fmt.Sprintf(",%d,%d]", i, i>>1))...)
b64 := base64.StdEncoding.EncodeToString(payload)
h := sha3.Sum512(append(seedBytes, []byte(b64)...))
if bytes.Compare(h[:len(diffBytes)], diffBytes) <= 0 {
return b64, true
}
}
return "", false
}
func powConfig(userAgent string) []any {
return []any{
3000,
formatPowTime(),
4294705152,
0,
userAgent,
"",
nil,
"en-US",
"en-US,es-US,en,es",
0,
"webdriver-false",
"location",
"window",
time.Now().UnixMilli(),
uuid.New().String(),
"",
16,
float64(time.Now().UnixMilli()),
}
}
func formatPowTime() string {
loc := time.FixedZone("EST", -5*60*60)
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05") + " GMT-0500 (Eastern Standard Time)"
}
func hexDecode(s string) ([]byte, error) {
if len(s)%2 != 0 {
return nil, errors.New("invalid hex length")
}
out := make([]byte, len(s)/2)
for i := 0; i < len(out); i++ {
byteVal, err := hexPair(s[i*2 : i*2+2])
if err != nil {
return nil, err
}
out[i] = byteVal
}
return out, nil
}
func hexPair(pair string) (byte, error) {
var v byte
for i := 0; i < 2; i++ {
c := pair[i]
var n byte
switch {
case c >= '0' && c <= '9':
n = c - '0'
case c >= 'a' && c <= 'f':
n = c - 'a' + 10
case c >= 'A' && c <= 'F':
n = c - 'A' + 10
default:
return 0, errors.New("invalid hex")
}
v = v<<4 | n
}
return v, nil
}
func stringFromJSON(data map[string]any, key string) string {
if data == nil {
return ""
}
if v, ok := data[key].(string); ok {
return v
}
return ""
}
func convertList(list []any) []map[string]any {
results := make([]map[string]any, 0, len(list))
for _, item := range list {
if m, ok := item.(map[string]any); ok {
results = append(results, m)
}
}
return results
}
func convertListFromAny(data map[string]any) []map[string]any {
if data == nil {
return nil
}
items, ok := data["items"].([]any)
if ok {
return convertList(items)
}
return nil
}
package sora
// ModelConfig 定义 Sora 模型配置。
type ModelConfig struct {
Type string
Width int
Height int
Orientation string
NFrames int
Model string
Size string
RequirePro bool
ExpansionLevel string
DurationS int
}
// ModelConfigs 定义所有模型配置。
var ModelConfigs = map[string]ModelConfig{
"gpt-image": {
Type: "image",
Width: 360,
Height: 360,
},
"gpt-image-landscape": {
Type: "image",
Width: 540,
Height: 360,
},
"gpt-image-portrait": {
Type: "image",
Width: 360,
Height: 540,
},
"sora2-landscape-10s": {
Type: "video",
Orientation: "landscape",
NFrames: 300,
},
"sora2-portrait-10s": {
Type: "video",
Orientation: "portrait",
NFrames: 300,
},
"sora2-landscape-15s": {
Type: "video",
Orientation: "landscape",
NFrames: 450,
},
"sora2-portrait-15s": {
Type: "video",
Orientation: "portrait",
NFrames: 450,
},
"sora2-landscape-25s": {
Type: "video",
Orientation: "landscape",
NFrames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2-portrait-25s": {
Type: "video",
Orientation: "portrait",
NFrames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-10s": {
Type: "video",
Orientation: "landscape",
NFrames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-10s": {
Type: "video",
Orientation: "portrait",
NFrames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-15s": {
Type: "video",
Orientation: "landscape",
NFrames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-15s": {
Type: "video",
Orientation: "portrait",
NFrames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-25s": {
Type: "video",
Orientation: "landscape",
NFrames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-25s": {
Type: "video",
Orientation: "portrait",
NFrames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-hd-landscape-10s": {
Type: "video",
Orientation: "landscape",
NFrames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-10s": {
Type: "video",
Orientation: "portrait",
NFrames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-landscape-15s": {
Type: "video",
Orientation: "landscape",
NFrames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-15s": {
Type: "video",
Orientation: "portrait",
NFrames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"prompt-enhance-short-10s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 10,
},
"prompt-enhance-short-15s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 15,
},
"prompt-enhance-short-20s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 20,
},
"prompt-enhance-medium-10s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 10,
},
"prompt-enhance-medium-15s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 15,
},
"prompt-enhance-medium-20s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 20,
},
"prompt-enhance-long-10s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 10,
},
"prompt-enhance-long-15s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 15,
},
"prompt-enhance-long-20s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 20,
},
}
// ModelListItem 返回模型列表条目。
type ModelListItem struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
Description string `json:"description"`
}
// ListModels 生成模型列表。
func ListModels() []ModelListItem {
models := make([]ModelListItem, 0, len(ModelConfigs))
for id, cfg := range ModelConfigs {
description := ""
switch cfg.Type {
case "image":
description = "Image generation"
if cfg.Width > 0 && cfg.Height > 0 {
description += " - " + itoa(cfg.Width) + "x" + itoa(cfg.Height)
}
case "video":
description = "Video generation"
if cfg.Orientation != "" {
description += " - " + cfg.Orientation
}
case "prompt_enhance":
description = "Prompt enhancement"
if cfg.ExpansionLevel != "" {
description += " - " + cfg.ExpansionLevel
}
if cfg.DurationS > 0 {
description += " (" + itoa(cfg.DurationS) + "s)"
}
default:
description = "Sora model"
}
models = append(models, ModelListItem{
ID: id,
Object: "model",
OwnedBy: "sora",
Description: description,
})
}
return models
}
func itoa(val int) string {
if val == 0 {
return "0"
}
neg := false
if val < 0 {
neg = true
val = -val
}
buf := [12]byte{}
i := len(buf)
for val > 0 {
i--
buf[i] = byte('0' + val%10)
val /= 10
}
if neg {
i--
buf[i] = '-'
}
return string(buf[i:])
}
package sora
import (
"regexp"
"strings"
)
var storyboardRe = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]`)
// IsStoryboardPrompt 检测是否为分镜提示词。
func IsStoryboardPrompt(prompt string) bool {
if strings.TrimSpace(prompt) == "" {
return false
}
return storyboardRe.MatchString(prompt)
}
// FormatStoryboardPrompt 将分镜提示词转换为 API 需要的格式。
func FormatStoryboardPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return prompt
}
matches := storyboardRe.FindAllStringSubmatchIndex(prompt, -1)
if len(matches) == 0 {
return prompt
}
firstIdx := matches[0][0]
instructions := strings.TrimSpace(prompt[:firstIdx])
shotPattern := regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
shotMatches := shotPattern.FindAllStringSubmatch(prompt, -1)
if len(shotMatches) == 0 {
return prompt
}
shots := make([]string, 0, len(shotMatches))
for i, sm := range shotMatches {
if len(sm) < 3 {
continue
}
duration := strings.TrimSpace(sm[1])
scene := strings.TrimSpace(sm[2])
shots = append(shots, "Shot "+itoa(i+1)+":\nduration: "+duration+"sec\nScene: "+scene)
}
timeline := strings.Join(shots, "\n\n")
if instructions != "" {
return "current timeline:\n" + timeline + "\n\ninstructions:\n" + instructions
}
return timeline
}
// ExtractRemixID 提取分享链接中的 remix ID。
func ExtractRemixID(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
re := regexp.MustCompile(`s_[a-f0-9]{32}`)
match := re.FindString(text)
return match
}
package uuidv7
import (
"crypto/rand"
"fmt"
"time"
)
// New returns a UUIDv7 string.
func New() (string, error) {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return "", err
}
ms := uint64(time.Now().UnixMilli())
b[0] = byte(ms >> 40)
b[1] = byte(ms >> 32)
b[2] = byte(ms >> 24)
b[3] = byte(ms >> 16)
b[4] = byte(ms >> 8)
b[5] = byte(ms)
b[6] = (b[6] & 0x0f) | 0x70
b[8] = (b[8] & 0x3f) | 0x80
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
uint32(b[0])<<24|uint32(b[1])<<16|uint32(b[2])<<8|uint32(b[3]),
uint16(b[4])<<8|uint16(b[5]),
uint16(b[6])<<8|uint16(b[7]),
uint16(b[8])<<8|uint16(b[9]),
uint64(b[10])<<40|uint64(b[11])<<32|uint64(b[12])<<24|uint64(b[13])<<16|uint64(b[14])<<8|uint64(b[15]),
), nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment