Commit 987589ea authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'test' into release

parents 372e04f6 03f69dd3
...@@ -32,7 +32,7 @@ jobs: ...@@ -32,7 +32,7 @@ jobs:
working-directory: backend working-directory: backend
run: | run: |
go install github.com/securego/gosec/v2/cmd/gosec@latest go install github.com/securego/gosec/v2/cmd/gosec@latest
gosec -severity high -confidence high ./... gosec -conf .gosec.json -severity high -confidence high ./...
frontend-security: frontend-security:
runs-on: ubuntu-latest runs-on: ubuntu-latest
......
{
"global": {
"exclude": "G704"
}
}
...@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider) soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
......
...@@ -669,6 +669,7 @@ var ( ...@@ -669,6 +669,7 @@ var (
{Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64}, {Name: "api_key_id", Type: field.TypeInt64},
{Name: "account_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64},
...@@ -684,31 +685,31 @@ var ( ...@@ -684,31 +685,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
...@@ -717,32 +718,32 @@ var ( ...@@ -717,32 +718,32 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[27]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
...@@ -757,12 +758,12 @@ var ( ...@@ -757,12 +758,12 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
}, },
}, },
} }
......
...@@ -15980,6 +15980,7 @@ type UsageLogMutation struct { ...@@ -15980,6 +15980,7 @@ type UsageLogMutation struct {
addimage_count *int addimage_count *int
image_size *string image_size *string
media_type *string media_type *string
cache_ttl_overridden *bool
created_at *time.Time created_at *time.Time
clearedFields map[string]struct{} clearedFields map[string]struct{}
user *int64 user *int64
...@@ -17655,6 +17656,42 @@ func (m *UsageLogMutation) ResetMediaType() { ...@@ -17655,6 +17656,42 @@ func (m *UsageLogMutation) ResetMediaType() {
delete(m.clearedFields, usagelog.FieldMediaType) delete(m.clearedFields, usagelog.FieldMediaType)
} }
   
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
m.cache_ttl_overridden = &b
}
// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation.
func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) {
v := m.cache_ttl_overridden
if v == nil {
return
}
return *v, true
}
// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err)
}
return oldValue.CacheTTLOverridden, nil
}
// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field.
func (m *UsageLogMutation) ResetCacheTTLOverridden() {
m.cache_ttl_overridden = nil
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (m *UsageLogMutation) SetCreatedAt(t time.Time) { func (m *UsageLogMutation) SetCreatedAt(t time.Time) {
m.created_at = &t m.created_at = &t
...@@ -17860,7 +17897,7 @@ func (m *UsageLogMutation) Type() string { ...@@ -17860,7 +17897,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 31) fields := make([]string, 0, 32)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
...@@ -17951,6 +17988,9 @@ func (m *UsageLogMutation) Fields() []string { ...@@ -17951,6 +17988,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.media_type != nil { if m.media_type != nil {
fields = append(fields, usagelog.FieldMediaType) fields = append(fields, usagelog.FieldMediaType)
} }
if m.cache_ttl_overridden != nil {
fields = append(fields, usagelog.FieldCacheTTLOverridden)
}
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, usagelog.FieldCreatedAt) fields = append(fields, usagelog.FieldCreatedAt)
} }
...@@ -18022,6 +18062,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { ...@@ -18022,6 +18062,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageSize() return m.ImageSize()
case usagelog.FieldMediaType: case usagelog.FieldMediaType:
return m.MediaType() return m.MediaType()
case usagelog.FieldCacheTTLOverridden:
return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
return m.CreatedAt() return m.CreatedAt()
} }
...@@ -18093,6 +18135,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value ...@@ -18093,6 +18135,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageSize(ctx) return m.OldImageSize(ctx)
case usagelog.FieldMediaType: case usagelog.FieldMediaType:
return m.OldMediaType(ctx) return m.OldMediaType(ctx)
case usagelog.FieldCacheTTLOverridden:
return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
return m.OldCreatedAt(ctx) return m.OldCreatedAt(ctx)
} }
...@@ -18314,6 +18358,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { ...@@ -18314,6 +18358,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetMediaType(v) m.SetMediaType(v)
return nil return nil
case usagelog.FieldCacheTTLOverridden:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetCacheTTLOverridden(v)
return nil
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
v, ok := value.(time.Time) v, ok := value.(time.Time)
if !ok { if !ok {
...@@ -18736,6 +18787,9 @@ func (m *UsageLogMutation) ResetField(name string) error { ...@@ -18736,6 +18787,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldMediaType: case usagelog.FieldMediaType:
m.ResetMediaType() m.ResetMediaType()
return nil return nil
case usagelog.FieldCacheTTLOverridden:
m.ResetCacheTTLOverridden()
return nil
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
m.ResetCreatedAt() m.ResetCreatedAt()
return nil return nil
......
...@@ -821,8 +821,12 @@ func init() { ...@@ -821,8 +821,12 @@ func init() {
usagelogDescMediaType := usagelogFields[29].Descriptor() usagelogDescMediaType := usagelogFields[29].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[30].Descriptor() usagelogDescCreatedAt := usagelogFields[31].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()
......
...@@ -124,6 +124,10 @@ func (UsageLog) Fields() []ent.Field { ...@@ -124,6 +124,10 @@ func (UsageLog) Fields() []ent.Field {
Optional(). Optional().
Nillable(), Nillable(),
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),
// 时间戳(只有 created_at,日志不可修改) // 时间戳(只有 created_at,日志不可修改)
field.Time("created_at"). field.Time("created_at").
Default(time.Now). Default(time.Now).
......
...@@ -82,6 +82,8 @@ type UsageLog struct { ...@@ -82,6 +82,8 @@ type UsageLog struct {
ImageSize *string `json:"image_size,omitempty"` ImageSize *string `json:"image_size,omitempty"`
// MediaType holds the value of the "media_type" field. // MediaType holds the value of the "media_type" field.
MediaType *string `json:"media_type,omitempty"` MediaType *string `json:"media_type,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field. // CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
...@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { ...@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case usagelog.FieldStream: case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
...@@ -387,6 +389,12 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { ...@@ -387,6 +389,12 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.MediaType = new(string) _m.MediaType = new(string)
*_m.MediaType = value.String *_m.MediaType = value.String
} }
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
} else if value.Valid {
_m.CacheTTLOverridden = value.Bool
}
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok { if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i]) return fmt.Errorf("unexpected type %T for field created_at", values[i])
...@@ -562,6 +570,9 @@ func (_m *UsageLog) String() string { ...@@ -562,6 +570,9 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")
builder.WriteString("created_at=") builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')') builder.WriteByte(')')
......
...@@ -74,6 +74,8 @@ const ( ...@@ -74,6 +74,8 @@ const (
FieldImageSize = "image_size" FieldImageSize = "image_size"
// FieldMediaType holds the string denoting the media_type field in the database. // FieldMediaType holds the string denoting the media_type field in the database.
FieldMediaType = "media_type" FieldMediaType = "media_type"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database. // FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at" FieldCreatedAt = "created_at"
// EdgeUser holds the string denoting the user edge name in mutations. // EdgeUser holds the string denoting the user edge name in mutations.
...@@ -158,6 +160,7 @@ var Columns = []string{ ...@@ -158,6 +160,7 @@ var Columns = []string{
FieldImageCount, FieldImageCount,
FieldImageSize, FieldImageSize,
FieldMediaType, FieldMediaType,
FieldCacheTTLOverridden,
FieldCreatedAt, FieldCreatedAt,
} }
...@@ -216,6 +219,8 @@ var ( ...@@ -216,6 +219,8 @@ var (
ImageSizeValidator func(string) error ImageSizeValidator func(string) error
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
MediaTypeValidator func(string) error MediaTypeValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field. // DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time DefaultCreatedAt func() time.Time
) )
...@@ -378,6 +383,11 @@ func ByMediaType(opts ...sql.OrderTermOption) OrderOption { ...@@ -378,6 +383,11 @@ func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMediaType, opts...).ToFunc() return sql.OrderByField(FieldMediaType, opts...).ToFunc()
} }
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field. // ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
......
...@@ -205,6 +205,11 @@ func MediaType(v string) predicate.UsageLog { ...@@ -205,6 +205,11 @@ func MediaType(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
} }
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UsageLog { func CreatedAt(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
...@@ -1520,6 +1525,16 @@ func MediaTypeContainsFold(v string) predicate.UsageLog { ...@@ -1520,6 +1525,16 @@ func MediaTypeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
} }
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
}
// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UsageLog { func CreatedAtEQ(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
......
...@@ -407,6 +407,20 @@ func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { ...@@ -407,6 +407,20 @@ func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
return _c return _c
} }
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
return _c
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate {
if v != nil {
_c.SetCacheTTLOverridden(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
...@@ -545,6 +559,10 @@ func (_c *UsageLogCreate) defaults() { ...@@ -545,6 +559,10 @@ func (_c *UsageLogCreate) defaults() {
v := usagelog.DefaultImageCount v := usagelog.DefaultImageCount
_c.mutation.SetImageCount(v) _c.mutation.SetImageCount(v)
} }
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
v := usagelog.DefaultCacheTTLOverridden
_c.mutation.SetCacheTTLOverridden(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
v := usagelog.DefaultCreatedAt() v := usagelog.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
...@@ -646,6 +664,9 @@ func (_c *UsageLogCreate) check() error { ...@@ -646,6 +664,9 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
} }
} }
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
} }
...@@ -785,6 +806,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { ...@@ -785,6 +806,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value) _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
_node.MediaType = &value _node.MediaType = &value
} }
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
}
if value, ok := _c.mutation.CreatedAt(); ok { if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value _node.CreatedAt = value
...@@ -1448,6 +1473,18 @@ func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { ...@@ -1448,6 +1473,18 @@ func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
return u return u
} }
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
return u
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldCacheTTLOverridden)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
...@@ -2102,6 +2139,20 @@ func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { ...@@ -2102,6 +2139,20 @@ func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
}) })
} }
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetCacheTTLOverridden(v)
})
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateCacheTTLOverridden()
})
}
// Exec executes the query. // Exec executes the query.
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
...@@ -2922,6 +2973,20 @@ func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { ...@@ -2922,6 +2973,20 @@ func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
}) })
} }
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetCacheTTLOverridden(v)
})
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateCacheTTLOverridden()
})
}
// Exec executes the query. // Exec executes the query.
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {
......
...@@ -632,6 +632,20 @@ func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { ...@@ -632,6 +632,20 @@ func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
return _u return _u
} }
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
return _u
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate {
if v != nil {
_u.SetCacheTTLOverridden(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
...@@ -925,6 +939,9 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -925,6 +939,9 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.MediaTypeCleared() { if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString) _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
} }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
...@@ -1690,6 +1707,20 @@ func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { ...@@ -1690,6 +1707,20 @@ func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
return _u return _u
} }
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
return _u
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne {
if v != nil {
_u.SetCacheTTLOverridden(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
...@@ -2013,6 +2044,9 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err ...@@ -2013,6 +2044,9 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.MediaTypeCleared() { if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString) _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
} }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
......
...@@ -162,6 +162,8 @@ type TokenRefreshConfig struct { ...@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒) // 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
} }
type PricingConfig struct { type PricingConfig struct {
...@@ -272,14 +274,27 @@ type SoraClientConfig struct { ...@@ -272,14 +274,27 @@ type SoraClientConfig struct {
BaseURL string `mapstructure:"base_url"` BaseURL string `mapstructure:"base_url"`
TimeoutSeconds int `mapstructure:"timeout_seconds"` TimeoutSeconds int `mapstructure:"timeout_seconds"`
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"` MaxPollAttempts int `mapstructure:"max_poll_attempts"`
RecentTaskLimit int `mapstructure:"recent_task_limit"` RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
Debug bool `mapstructure:"debug"` Debug bool `mapstructure:"debug"`
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
Headers map[string]string `mapstructure:"headers"` Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"` UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
}
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
type SoraCurlCFFISidecarConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"`
Impersonate string `mapstructure:"impersonate"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
} }
// SoraStorageConfig 媒体存储配置 // SoraStorageConfig 媒体存储配置
...@@ -1111,14 +1126,22 @@ func setDefaults() { ...@@ -1111,14 +1126,22 @@ func setDefaults() {
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
viper.SetDefault("sora.client.timeout_seconds", 120) viper.SetDefault("sora.client.timeout_seconds", 120)
viper.SetDefault("sora.client.max_retries", 3) viper.SetDefault("sora.client.max_retries", 3)
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
viper.SetDefault("sora.client.poll_interval_seconds", 2) viper.SetDefault("sora.client.poll_interval_seconds", 2)
viper.SetDefault("sora.client.max_poll_attempts", 600) viper.SetDefault("sora.client.max_poll_attempts", 600)
viper.SetDefault("sora.client.recent_task_limit", 50) viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200) viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false) viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{}) viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false) viper.SetDefault("sora.client.disable_tls_fingerprint", false)
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
viper.SetDefault("sora.storage.type", "local") viper.SetDefault("sora.storage.type", "local")
viper.SetDefault("sora.storage.local_path", "") viper.SetDefault("sora.storage.local_path", "")
...@@ -1137,6 +1160,7 @@ func setDefaults() { ...@@ -1137,6 +1160,7 @@ func setDefaults() {
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file // Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
...@@ -1505,6 +1529,9 @@ func (c *Config) Validate() error { ...@@ -1505,6 +1529,9 @@ func (c *Config) Validate() error {
if c.Sora.Client.MaxRetries < 0 { if c.Sora.Client.MaxRetries < 0 {
return fmt.Errorf("sora.client.max_retries must be non-negative") return fmt.Errorf("sora.client.max_retries must be non-negative")
} }
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
}
if c.Sora.Client.PollIntervalSeconds < 0 { if c.Sora.Client.PollIntervalSeconds < 0 {
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
} }
...@@ -1521,6 +1548,18 @@ func (c *Config) Validate() error { ...@@ -1521,6 +1548,18 @@ func (c *Config) Validate() error {
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
} }
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
}
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
}
if !c.Sora.Client.CurlCFFISidecar.Enabled {
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
}
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
}
if c.Sora.Storage.MaxConcurrentDownloads < 0 { if c.Sora.Storage.MaxConcurrentDownloads < 0 {
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
} }
......
...@@ -1024,3 +1024,91 @@ func TestValidateConfigErrors(t *testing.T) { ...@@ -1024,3 +1024,91 @@ func TestValidateConfigErrors(t *testing.T) {
}) })
} }
} }
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
}
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
}
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
}
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
}
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
}
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
}
}
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
}
}
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
}
}
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
}
}
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
}
}
...@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc ...@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Account var out []service.Account
for { for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search) items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -200,7 +200,12 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -200,7 +200,12 @@ func (h *AccountHandler) List(c *gin.Context) {
search = search[:100] search = search[:100]
} }
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -1433,6 +1438,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { ...@@ -1433,6 +1438,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return return
} }
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
// Handle Claude/Anthropic accounts // Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models // For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() { if account.IsOAuth() {
...@@ -1542,7 +1553,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { ...@@ -1542,7 +1553,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0) accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { ...@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
...@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) { ...@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
......
...@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p ...@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return s.apiKeys, int64(len(s.apiKeys)), nil return s.apiKeys, int64(len(s.apiKeys)), nil
} }
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil return s.accounts, int64(len(s.accounts)), nil
} }
...@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr ...@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
} }
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
return &service.ProxyQualityCheckResult{
ProxyID: id,
Score: 95,
Grade: "A",
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
PassedCount: 5,
WarnCount: 0,
FailedCount: 0,
ChallengeCount: 0,
CheckedAt: time.Now().Unix(),
Items: []service.ProxyQualityCheckItem{
{Target: "base_connectivity", Status: "pass", Message: "ok"},
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
{Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil return s.redeems, int64(len(s.redeems)), nil
} }
......
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct { ...@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService adminService service.AdminService
} }
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{ return &OpenAIOAuthHandler{
...@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { ...@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct { type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
...@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct { type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"` RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token // POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string var proxyURL string
if req.ProxyID != nil { if req.ProxyID != nil {
...@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
} }
} }
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo) response.Success(c, tokenInfo)
} }
// RefreshAccountToken refreshes token for a specific OpenAI account // ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh // POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
...@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return return
} }
// Ensure account is OpenAI platform platform := oauthPlatformFromPath(c)
if !account.IsOpenAI() { if account.Platform != platform {
response.BadRequest(c, "Account is not an OpenAI account") response.BadRequest(c, "Account platform does not match OAuth endpoint")
return return
} }
...@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount)) response.Success(c, dto.AccountFromService(updatedAccount))
} }
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth // POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct { var req struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"` Name string `json:"name"`
...@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info // Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided // Use email as default name if not provided
name := req.Name name := req.Name
if name == "" && tokenInfo.Email != "" { if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email name = tokenInfo.Email
} }
if name == "" { if name == "" {
if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account" name = "OpenAI OAuth Account"
} }
}
// Create account // Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name, Name: name,
Platform: "openai", Platform: platform,
Type: "oauth", Type: "oauth",
Credentials: credentials, Credentials: credentials,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment