Commit 2b70d1d3 authored by IanShaw027's avatar IanShaw027
Browse files

merge upstream main into fix/bug-cleanup-main

parents b37afd68 00c08c57
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// Group is the model entity for the Group schema. // Group is the model entity for the Group schema.
...@@ -76,6 +77,8 @@ type Group struct { ...@@ -76,6 +77,8 @@ type Group struct {
RequirePrivacySet bool `json:"require_privacy_set,omitempty"` RequirePrivacySet bool `json:"require_privacy_set,omitempty"`
// 默认映射模型 ID,当账号级映射找不到时使用此值 // 默认映射模型 ID,当账号级映射找不到时使用此值
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set. // The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"` Edges GroupEdges `json:"edges"`
...@@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { ...@@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes: case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
values[i] = new([]byte) values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
...@@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error { ...@@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.DefaultMappedModel = value.String _m.DefaultMappedModel = value.String
} }
case group.FieldMessagesDispatchModelConfig:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field messages_dispatch_model_config", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.MessagesDispatchModelConfig); err != nil {
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
}
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
...@@ -585,6 +596,9 @@ func (_m *Group) String() string { ...@@ -585,6 +596,9 @@ func (_m *Group) String() string {
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("default_mapped_model=") builder.WriteString("default_mapped_model=")
builder.WriteString(_m.DefaultMappedModel) builder.WriteString(_m.DefaultMappedModel)
builder.WriteString(", ")
builder.WriteString("messages_dispatch_model_config=")
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
const ( const (
...@@ -73,6 +74,8 @@ const ( ...@@ -73,6 +74,8 @@ const (
FieldRequirePrivacySet = "require_privacy_set" FieldRequirePrivacySet = "require_privacy_set"
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
FieldDefaultMappedModel = "default_mapped_model" FieldDefaultMappedModel = "default_mapped_model"
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys" EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
...@@ -177,6 +180,7 @@ var Columns = []string{ ...@@ -177,6 +180,7 @@ var Columns = []string{
FieldRequireOauthOnly, FieldRequireOauthOnly,
FieldRequirePrivacySet, FieldRequirePrivacySet,
FieldDefaultMappedModel, FieldDefaultMappedModel,
FieldMessagesDispatchModelConfig,
} }
var ( var (
...@@ -252,6 +256,8 @@ var ( ...@@ -252,6 +256,8 @@ var (
DefaultDefaultMappedModel string DefaultDefaultMappedModel string
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
DefaultMappedModelValidator func(string) error DefaultMappedModelValidator func(string) error
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
) )
// OrderOption defines the ordering options for the Group queries. // OrderOption defines the ordering options for the Group queries.
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// GroupCreate is the builder for creating a Group entity. // GroupCreate is the builder for creating a Group entity.
...@@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate { ...@@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
return _c return _c
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (_c *GroupCreate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupCreate {
_c.mutation.SetMessagesDispatchModelConfig(v)
return _c
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupCreate {
if v != nil {
_c.SetMessagesDispatchModelConfig(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...) _c.mutation.AddAPIKeyIDs(ids...)
...@@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error { ...@@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultMappedModel v := group.DefaultDefaultMappedModel
_c.mutation.SetDefaultMappedModel(v) _c.mutation.SetDefaultMappedModel(v)
} }
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
v := group.DefaultMessagesDispatchModelConfig
_c.mutation.SetMessagesDispatchModelConfig(v)
}
return nil return nil
} }
...@@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error { ...@@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
} }
} }
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
}
return nil return nil
} }
...@@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { ...@@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
_node.DefaultMappedModel = value _node.DefaultMappedModel = value
} }
if value, ok := _c.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
_node.MessagesDispatchModelConfig = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert { ...@@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
return u return u
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (u *GroupUpsert) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsert {
u.Set(group.FieldMessagesDispatchModelConfig, v)
return u
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
u.SetExcluded(group.FieldMessagesDispatchModelConfig)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
...@@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne { ...@@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
}) })
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (u *GroupUpsertOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetMessagesDispatchModelConfig(v)
})
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateMessagesDispatchModelConfig()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error { func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
...@@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk { ...@@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
}) })
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (u *GroupUpsertBulk) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetMessagesDispatchModelConfig(v)
})
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateMessagesDispatchModelConfig()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error { func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// GroupUpdate is the builder for updating Group entities. // GroupUpdate is the builder for updating Group entities.
...@@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate { ...@@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
return _u return _u
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (_u *GroupUpdate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate {
_u.mutation.SetMessagesDispatchModelConfig(v)
return _u
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate {
if v != nil {
_u.SetMessagesDispatchModelConfig(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.DefaultMappedModel(); ok { if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
} }
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO ...@@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
return _u return _u
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (_u *GroupUpdateOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne {
_u.mutation.SetMessagesDispatchModelConfig(v)
return _u
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne {
if v != nil {
_u.SetMessagesDispatchModelConfig(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) ...@@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.DefaultMappedModel(); ok { if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
} }
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
......
...@@ -407,6 +407,7 @@ var ( ...@@ -407,6 +407,7 @@ var (
{Name: "require_oauth_only", Type: field.TypeBool, Default: false}, {Name: "require_oauth_only", Type: field.TypeBool, Default: false},
{Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "require_privacy_set", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
} }
// GroupsTable holds the schema information for the "groups" table. // GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{ GroupsTable = &schema.Table{
......
...@@ -8246,6 +8246,7 @@ type GroupMutation struct { ...@@ -8246,6 +8246,7 @@ type GroupMutation struct {
require_oauth_only *bool require_oauth_only *bool
require_privacy_set *bool require_privacy_set *bool
default_mapped_model *string default_mapped_model *string
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
clearedFields map[string]struct{} clearedFields map[string]struct{}
api_keys map[int64]struct{} api_keys map[int64]struct{}
removedapi_keys map[int64]struct{} removedapi_keys map[int64]struct{}
...@@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() { ...@@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
m.default_mapped_model = nil m.default_mapped_model = nil
} }
   
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) {
m.messages_dispatch_model_config = &damdmc
}
// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation.
func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) {
v := m.messages_dispatch_model_config
if v == nil {
return
}
return *v, true
}
// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err)
}
return oldValue.MessagesDispatchModelConfig, nil
}
// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field.
func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
m.messages_dispatch_model_config = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil { if m.api_keys == nil {
...@@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string { ...@@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *GroupMutation) Fields() []string { func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 29) fields := make([]string, 0, 30)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt) fields = append(fields, group.FieldCreatedAt)
} }
...@@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string { ...@@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string {
if m.default_mapped_model != nil { if m.default_mapped_model != nil {
fields = append(fields, group.FieldDefaultMappedModel) fields = append(fields, group.FieldDefaultMappedModel)
} }
if m.messages_dispatch_model_config != nil {
fields = append(fields, group.FieldMessagesDispatchModelConfig)
}
return fields return fields
} }
   
...@@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { ...@@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.RequirePrivacySet() return m.RequirePrivacySet()
case group.FieldDefaultMappedModel: case group.FieldDefaultMappedModel:
return m.DefaultMappedModel() return m.DefaultMappedModel()
case group.FieldMessagesDispatchModelConfig:
return m.MessagesDispatchModelConfig()
} }
return nil, false return nil, false
} }
...@@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e ...@@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldRequirePrivacySet(ctx) return m.OldRequirePrivacySet(ctx)
case group.FieldDefaultMappedModel: case group.FieldDefaultMappedModel:
return m.OldDefaultMappedModel(ctx) return m.OldDefaultMappedModel(ctx)
case group.FieldMessagesDispatchModelConfig:
return m.OldMessagesDispatchModelConfig(ctx)
} }
return nil, fmt.Errorf("unknown Group field %s", name) return nil, fmt.Errorf("unknown Group field %s", name)
} }
...@@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { ...@@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
} }
m.SetDefaultMappedModel(v) m.SetDefaultMappedModel(v)
return nil return nil
case group.FieldMessagesDispatchModelConfig:
v, ok := value.(domain.OpenAIMessagesDispatchModelConfig)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetMessagesDispatchModelConfig(v)
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
...@@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error { ...@@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultMappedModel: case group.FieldDefaultMappedModel:
m.ResetDefaultMappedModel() m.ResetDefaultMappedModel()
return nil return nil
case group.FieldMessagesDispatchModelConfig:
m.ResetMessagesDispatchModelConfig()
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
......
...@@ -28,6 +28,7 @@ import ( ...@@ -28,6 +28,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue" "github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// The init function reads all schema descriptors with runtime code // The init function reads all schema descriptors with runtime code
...@@ -468,6 +469,10 @@ func init() { ...@@ -468,6 +469,10 @@ func init() {
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0 _ = idempotencyrecordMixinFields0
......
...@@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field { ...@@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field {
MaxLen(100). MaxLen(100).
Default(""). Default("").
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"), Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
field.JSON("messages_dispatch_model_config", domain.OpenAIMessagesDispatchModelConfig{}).
Default(domain.OpenAIMessagesDispatchModelConfig{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
} }
} }
......
...@@ -65,6 +65,7 @@ type Config struct { ...@@ -65,6 +65,7 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"` JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"` Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"` Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"` RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"` Pricing PricingConfig `mapstructure:"pricing"`
...@@ -184,6 +185,34 @@ type LinuxDoConnectConfig struct { ...@@ -184,6 +185,34 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
} }
type OIDCConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
// TokenRefreshConfig OAuth token自动刷新配置 // TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct { type TokenRefreshConfig struct {
// 是否启用自动刷新 // 是否启用自动刷新
...@@ -318,6 +347,12 @@ type GatewayConfig struct { ...@@ -318,6 +347,12 @@ type GatewayConfig struct {
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"` ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
// 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
// ForcedCodexInstructionsTemplate: 启动时从模板文件读取并缓存的模板内容。
// 该字段不直接参与配置反序列化,仅用于请求热路径避免重复读盘。
ForcedCodexInstructionsTemplate string `mapstructure:"-"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
...@@ -972,6 +1007,23 @@ func load(allowMissingJWTSecret bool) (*Config, error) { ...@@ -972,6 +1007,23 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
cfg.OIDC.IssuerURL = strings.TrimSpace(cfg.OIDC.IssuerURL)
cfg.OIDC.DiscoveryURL = strings.TrimSpace(cfg.OIDC.DiscoveryURL)
cfg.OIDC.AuthorizeURL = strings.TrimSpace(cfg.OIDC.AuthorizeURL)
cfg.OIDC.TokenURL = strings.TrimSpace(cfg.OIDC.TokenURL)
cfg.OIDC.UserInfoURL = strings.TrimSpace(cfg.OIDC.UserInfoURL)
cfg.OIDC.JWKSURL = strings.TrimSpace(cfg.OIDC.JWKSURL)
cfg.OIDC.Scopes = strings.TrimSpace(cfg.OIDC.Scopes)
cfg.OIDC.RedirectURL = strings.TrimSpace(cfg.OIDC.RedirectURL)
cfg.OIDC.FrontendRedirectURL = strings.TrimSpace(cfg.OIDC.FrontendRedirectURL)
cfg.OIDC.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.OIDC.TokenAuthMethod))
cfg.OIDC.AllowedSigningAlgs = strings.TrimSpace(cfg.OIDC.AllowedSigningAlgs)
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
...@@ -983,6 +1035,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) { ...@@ -983,6 +1035,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment) cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
cfg.Gateway.ForcedCodexInstructionsTemplateFile = strings.TrimSpace(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
if cfg.Gateway.ForcedCodexInstructionsTemplateFile != "" {
content, err := os.ReadFile(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
if err != nil {
return nil, fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err)
}
cfg.Gateway.ForcedCodexInstructionsTemplate = string(content)
}
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0)时回退旧键;新键优先。 // 新键未配置(<=0)时回退旧键;新键优先。
...@@ -1142,6 +1202,30 @@ func setDefaults() { ...@@ -1142,6 +1202,30 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "") viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "") viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// Generic OIDC OAuth 登录
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
viper.SetDefault("oidc_connect.client_id", "")
viper.SetDefault("oidc_connect.client_secret", "")
viper.SetDefault("oidc_connect.issuer_url", "")
viper.SetDefault("oidc_connect.discovery_url", "")
viper.SetDefault("oidc_connect.authorize_url", "")
viper.SetDefault("oidc_connect.token_url", "")
viper.SetDefault("oidc_connect.userinfo_url", "")
viper.SetDefault("oidc_connect.jwks_url", "")
viper.SetDefault("oidc_connect.scopes", "openid email profile")
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
viper.SetDefault("oidc_connect.use_pkce", false)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
viper.SetDefault("oidc_connect.require_email_verified", false)
viper.SetDefault("oidc_connect.userinfo_email_path", "")
viper.SetDefault("oidc_connect.userinfo_id_path", "")
viper.SetDefault("oidc_connect.userinfo_username_path", "")
// Database // Database
viper.SetDefault("database.host", "localhost") viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432) viper.SetDefault("database.port", 5432)
...@@ -1578,6 +1662,87 @@ func (c *Config) Validate() error { ...@@ -1578,6 +1662,87 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
} }
if c.OIDC.Enabled {
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.IssuerURL) == "" {
return fmt.Errorf("oidc_connect.issuer_url is required when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.RedirectURL) == "" {
return fmt.Errorf("oidc_connect.redirect_url is required when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.FrontendRedirectURL) == "" {
return fmt.Errorf("oidc_connect.frontend_redirect_url is required when oidc_connect.enabled=true")
}
if !scopeContainsOpenID(c.OIDC.Scopes) {
return fmt.Errorf("oidc_connect.scopes must contain openid")
}
method := strings.ToLower(strings.TrimSpace(c.OIDC.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic", "none":
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
if c.OIDC.ClockSkewSeconds < 0 || c.OIDC.ClockSkewSeconds > 600 {
return fmt.Errorf("oidc_connect.clock_skew_seconds must be between 0 and 600")
}
if c.OIDC.ValidateIDToken && strings.TrimSpace(c.OIDC.AllowedSigningAlgs) == "" {
return fmt.Errorf("oidc_connect.allowed_signing_algs is required when oidc_connect.validate_id_token=true")
}
if err := ValidateAbsoluteHTTPURL(c.OIDC.IssuerURL); err != nil {
return fmt.Errorf("oidc_connect.issuer_url invalid: %w", err)
}
if v := strings.TrimSpace(c.OIDC.DiscoveryURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.discovery_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.AuthorizeURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.authorize_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.TokenURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.token_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.UserInfoURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.userinfo_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.JWKSURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.jwks_url invalid: %w", err)
}
}
if err := ValidateAbsoluteHTTPURL(c.OIDC.RedirectURL); err != nil {
return fmt.Errorf("oidc_connect.redirect_url invalid: %w", err)
}
if err := ValidateFrontendRedirectURL(c.OIDC.FrontendRedirectURL); err != nil {
return fmt.Errorf("oidc_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("oidc_connect.issuer_url", c.OIDC.IssuerURL)
warnIfInsecureURL("oidc_connect.discovery_url", c.OIDC.DiscoveryURL)
warnIfInsecureURL("oidc_connect.authorize_url", c.OIDC.AuthorizeURL)
warnIfInsecureURL("oidc_connect.token_url", c.OIDC.TokenURL)
warnIfInsecureURL("oidc_connect.userinfo_url", c.OIDC.UserInfoURL)
warnIfInsecureURL("oidc_connect.jwks_url", c.OIDC.JWKSURL)
warnIfInsecureURL("oidc_connect.redirect_url", c.OIDC.RedirectURL)
warnIfInsecureURL("oidc_connect.frontend_redirect_url", c.OIDC.FrontendRedirectURL)
}
if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 { if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
...@@ -2196,6 +2361,15 @@ func ValidateFrontendRedirectURL(raw string) error { ...@@ -2196,6 +2361,15 @@ func ValidateFrontendRedirectURL(raw string) error {
return nil return nil
} }
func scopeContainsOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" {
return true
}
}
return false
}
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议 // isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
func isHTTPScheme(scheme string) bool { func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
......
package config package config
import ( import (
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"time" "time"
...@@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { ...@@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
} }
} }
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
tempDir := t.TempDir()
templatePath := filepath.Join(tempDir, "codex-instructions.md.tmpl")
configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir)
cfg, err := Load()
require.NoError(t, err)
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
}
func TestLoadDefaultSecurityToggles(t *testing.T) { func TestLoadDefaultSecurityToggles(t *testing.T) {
resetViperWithJWTSecret(t) resetViperWithJWTSecret(t)
...@@ -351,6 +370,60 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { ...@@ -351,6 +370,60 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
} }
} }
func TestValidateOIDCScopesMustContainOpenID(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.OIDC.Enabled = true
cfg.OIDC.ClientID = "oidc-client"
cfg.OIDC.ClientSecret = "oidc-secret"
cfg.OIDC.IssuerURL = "https://issuer.example.com"
cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
cfg.OIDC.TokenURL = "https://issuer.example.com/token"
cfg.OIDC.JWKSURL = "https://issuer.example.com/jwks"
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "profile email"
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when scopes do not include openid, got nil")
}
if !strings.Contains(err.Error(), "oidc_connect.scopes") {
t.Fatalf("Validate() expected oidc_connect.scopes error, got: %v", err)
}
}
func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.OIDC.Enabled = true
cfg.OIDC.ClientID = "oidc-client"
cfg.OIDC.ClientSecret = "oidc-secret"
cfg.OIDC.IssuerURL = "https://issuer.example.com"
cfg.OIDC.AuthorizeURL = ""
cfg.OIDC.TokenURL = ""
cfg.OIDC.JWKSURL = ""
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.ValidateIDToken = true
err = cfg.Validate()
if err != nil {
t.Fatalf("Validate() expected issuer-only OIDC config to pass with discovery fallback, got: %v", err)
}
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) { func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
resetViperWithJWTSecret(t) resetViperWithJWTSecret(t)
......
package domain
// OpenAIMessagesDispatchModelConfig controls how Anthropic /v1/messages
// requests are mapped onto OpenAI/Codex models.
type OpenAIMessagesDispatchModelConfig struct {
OpusMappedModel string `json:"opus_mapped_model,omitempty"`
SonnetMappedModel string `json:"sonnet_mapped_model,omitempty"`
HaikuMappedModel string `json:"haiku_mapped_model,omitempty"`
ExactModelMappings map[string]string `json:"exact_model_mappings,omitempty"`
}
...@@ -105,10 +105,11 @@ type CreateGroupRequest struct { ...@@ -105,10 +105,11 @@ type CreateGroupRequest struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
RequireOAuthOnly bool `json:"require_oauth_only"` RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"` RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"` DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 从指定分组复制账号(创建后自动绑定) // 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -139,10 +140,11 @@ type UpdateGroupRequest struct { ...@@ -139,10 +140,11 @@ type UpdateGroupRequest struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"` SupportedModelScopes *[]string `json:"supported_model_scopes"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
RequireOAuthOnly *bool `json:"require_oauth_only"` RequireOAuthOnly *bool `json:"require_oauth_only"`
RequirePrivacySet *bool `json:"require_privacy_set"` RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"` DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -259,6 +261,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -259,6 +261,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequireOAuthOnly: req.RequireOAuthOnly, RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet, RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel, DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
...@@ -309,6 +312,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -309,6 +312,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequireOAuthOnly: req.RequireOAuthOnly, RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet, RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel, DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
......
...@@ -35,6 +35,15 @@ func generateMenuItemID() (string, error) { ...@@ -35,6 +35,15 @@ func generateMenuItemID() (string, error) {
return hex.EncodeToString(b), nil return hex.EncodeToString(b), nil
} }
func scopesContainOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" {
return true
}
}
return false
}
// SettingHandler 系统设置处理器 // SettingHandler 系统设置处理器
type SettingHandler struct { type SettingHandler struct {
settingService *service.SettingService settingService *service.SettingService
...@@ -96,6 +105,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -96,6 +105,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
LinuxDoConnectClientID: settings.LinuxDoConnectClientID, LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: settings.OIDCConnectEnabled,
OIDCConnectProviderName: settings.OIDCConnectProviderName,
OIDCConnectClientID: settings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
OIDCConnectScopes: settings.OIDCConnectScopes,
OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
SiteName: settings.SiteName, SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo, SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle, SiteSubtitle: settings.SiteSubtitle,
...@@ -166,6 +197,30 @@ type UpdateSettingsRequest struct { ...@@ -166,6 +197,30 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
OIDCConnectClientSecret string `json:"oidc_connect_client_secret"`
OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"`
OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"`
OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"`
OIDCConnectTokenURL string `json:"oidc_connect_token_url"`
OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"`
OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"`
OIDCConnectScopes string `json:"oidc_connect_scopes"`
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"`
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
// OEM设置 // OEM设置
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
...@@ -335,6 +390,122 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -335,6 +390,122 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
} }
// Generic OIDC 参数验证
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
req.OIDCConnectClientSecret = strings.TrimSpace(req.OIDCConnectClientSecret)
req.OIDCConnectIssuerURL = strings.TrimSpace(req.OIDCConnectIssuerURL)
req.OIDCConnectDiscoveryURL = strings.TrimSpace(req.OIDCConnectDiscoveryURL)
req.OIDCConnectAuthorizeURL = strings.TrimSpace(req.OIDCConnectAuthorizeURL)
req.OIDCConnectTokenURL = strings.TrimSpace(req.OIDCConnectTokenURL)
req.OIDCConnectUserInfoURL = strings.TrimSpace(req.OIDCConnectUserInfoURL)
req.OIDCConnectJWKSURL = strings.TrimSpace(req.OIDCConnectJWKSURL)
req.OIDCConnectScopes = strings.TrimSpace(req.OIDCConnectScopes)
req.OIDCConnectRedirectURL = strings.TrimSpace(req.OIDCConnectRedirectURL)
req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(req.OIDCConnectFrontendRedirectURL)
req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(req.OIDCConnectTokenAuthMethod))
req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(req.OIDCConnectAllowedSigningAlgs)
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
if req.OIDCConnectProviderName == "" {
req.OIDCConnectProviderName = "OIDC"
}
if req.OIDCConnectClientID == "" {
response.BadRequest(c, "OIDC Client ID is required when enabled")
return
}
if req.OIDCConnectIssuerURL == "" {
response.BadRequest(c, "OIDC Issuer URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectIssuerURL); err != nil {
response.BadRequest(c, "OIDC Issuer URL must be an absolute http(s) URL")
return
}
if req.OIDCConnectDiscoveryURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectDiscoveryURL); err != nil {
response.BadRequest(c, "OIDC Discovery URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectAuthorizeURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectAuthorizeURL); err != nil {
response.BadRequest(c, "OIDC Authorize URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectTokenURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectTokenURL); err != nil {
response.BadRequest(c, "OIDC Token URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectUserInfoURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectUserInfoURL); err != nil {
response.BadRequest(c, "OIDC UserInfo URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectRedirectURL == "" {
response.BadRequest(c, "OIDC Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectRedirectURL); err != nil {
response.BadRequest(c, "OIDC Redirect URL must be an absolute http(s) URL")
return
}
if req.OIDCConnectFrontendRedirectURL == "" {
response.BadRequest(c, "OIDC Frontend Redirect URL is required when enabled")
return
}
if err := config.ValidateFrontendRedirectURL(req.OIDCConnectFrontendRedirectURL); err != nil {
response.BadRequest(c, "OIDC Frontend Redirect URL is invalid")
return
}
if !scopesContainOpenID(req.OIDCConnectScopes) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
return
}
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
if req.OIDCConnectValidateIDToken {
if req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
response.BadRequest(c, "OIDC JWKS URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectTokenAuthMethod == "" || req.OIDCConnectTokenAuthMethod == "client_secret_post" || req.OIDCConnectTokenAuthMethod == "client_secret_basic" {
if req.OIDCConnectClientSecret == "" {
if previousSettings.OIDCConnectClientSecret == "" {
response.BadRequest(c, "OIDC Client Secret is required when enabled")
return
}
req.OIDCConnectClientSecret = previousSettings.OIDCConnectClientSecret
}
}
}
// “购买订阅”页面配置验证 // “购买订阅”页面配置验证
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
if req.PurchaseSubscriptionEnabled != nil { if req.PurchaseSubscriptionEnabled != nil {
...@@ -565,6 +736,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -565,6 +736,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
OIDCConnectScopes: req.OIDCConnectScopes,
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: req.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
SiteName: req.SiteName, SiteName: req.SiteName,
SiteLogo: req.SiteLogo, SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle, SiteSubtitle: req.SiteSubtitle,
...@@ -682,6 +875,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -682,6 +875,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
SiteName: updatedSettings.SiteName, SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo, SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle, SiteSubtitle: updatedSettings.SiteSubtitle,
...@@ -802,6 +1017,72 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -802,6 +1017,72 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url") changed = append(changed, "linuxdo_connect_redirect_url")
} }
if before.OIDCConnectEnabled != after.OIDCConnectEnabled {
changed = append(changed, "oidc_connect_enabled")
}
if before.OIDCConnectProviderName != after.OIDCConnectProviderName {
changed = append(changed, "oidc_connect_provider_name")
}
if before.OIDCConnectClientID != after.OIDCConnectClientID {
changed = append(changed, "oidc_connect_client_id")
}
if req.OIDCConnectClientSecret != "" {
changed = append(changed, "oidc_connect_client_secret")
}
if before.OIDCConnectIssuerURL != after.OIDCConnectIssuerURL {
changed = append(changed, "oidc_connect_issuer_url")
}
if before.OIDCConnectDiscoveryURL != after.OIDCConnectDiscoveryURL {
changed = append(changed, "oidc_connect_discovery_url")
}
if before.OIDCConnectAuthorizeURL != after.OIDCConnectAuthorizeURL {
changed = append(changed, "oidc_connect_authorize_url")
}
if before.OIDCConnectTokenURL != after.OIDCConnectTokenURL {
changed = append(changed, "oidc_connect_token_url")
}
if before.OIDCConnectUserInfoURL != after.OIDCConnectUserInfoURL {
changed = append(changed, "oidc_connect_userinfo_url")
}
if before.OIDCConnectJWKSURL != after.OIDCConnectJWKSURL {
changed = append(changed, "oidc_connect_jwks_url")
}
if before.OIDCConnectScopes != after.OIDCConnectScopes {
changed = append(changed, "oidc_connect_scopes")
}
if before.OIDCConnectRedirectURL != after.OIDCConnectRedirectURL {
changed = append(changed, "oidc_connect_redirect_url")
}
if before.OIDCConnectFrontendRedirectURL != after.OIDCConnectFrontendRedirectURL {
changed = append(changed, "oidc_connect_frontend_redirect_url")
}
if before.OIDCConnectTokenAuthMethod != after.OIDCConnectTokenAuthMethod {
changed = append(changed, "oidc_connect_token_auth_method")
}
if before.OIDCConnectUsePKCE != after.OIDCConnectUsePKCE {
changed = append(changed, "oidc_connect_use_pkce")
}
if before.OIDCConnectValidateIDToken != after.OIDCConnectValidateIDToken {
changed = append(changed, "oidc_connect_validate_id_token")
}
if before.OIDCConnectAllowedSigningAlgs != after.OIDCConnectAllowedSigningAlgs {
changed = append(changed, "oidc_connect_allowed_signing_algs")
}
if before.OIDCConnectClockSkewSeconds != after.OIDCConnectClockSkewSeconds {
changed = append(changed, "oidc_connect_clock_skew_seconds")
}
if before.OIDCConnectRequireEmailVerified != after.OIDCConnectRequireEmailVerified {
changed = append(changed, "oidc_connect_require_email_verified")
}
if before.OIDCConnectUserInfoEmailPath != after.OIDCConnectUserInfoEmailPath {
changed = append(changed, "oidc_connect_userinfo_email_path")
}
if before.OIDCConnectUserInfoIDPath != after.OIDCConnectUserInfoIDPath {
changed = append(changed, "oidc_connect_userinfo_id_path")
}
if before.OIDCConnectUserInfoUsernamePath != after.OIDCConnectUserInfoUsernamePath {
changed = append(changed, "oidc_connect_userinfo_username_path")
}
if before.SiteName != after.SiteName { if before.SiteName != after.SiteName {
changed = append(changed, "site_name") changed = append(changed, "site_name")
} }
......
package handler
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc"
oidcOAuthStateCookieName = "oidc_oauth_state"
oidcOAuthVerifierCookie = "oidc_oauth_verifier"
oidcOAuthRedirectCookie = "oidc_oauth_redirect"
oidcOAuthNonceCookie = "oidc_oauth_nonce"
oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
oidcOAuthDefaultRedirectTo = "/dashboard"
oidcOAuthDefaultFrontendCB = "/auth/oidc/callback"
)
type oidcTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
type oidcTokenExchangeError struct {
StatusCode int
ProviderError string
ProviderDescription string
Body string
}
func (e *oidcTokenExchangeError) Error() string {
if e == nil {
return ""
}
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
if strings.TrimSpace(e.ProviderError) != "" {
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
}
if strings.TrimSpace(e.ProviderDescription) != "" {
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
}
return strings.Join(parts, " ")
}
type oidcIDTokenClaims struct {
Email string `json:"email,omitempty"`
EmailVerified *bool `json:"email_verified,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Name string `json:"name,omitempty"`
Nonce string `json:"nonce,omitempty"`
Azp string `json:"azp,omitempty"`
jwt.RegisteredClaims
}
type oidcUserInfoClaims struct {
Email string
Username string
Subject string
EmailVerified *bool
}
type oidcJWKSet struct {
Keys []oidcJWK `json:"keys"`
}
type oidcJWK struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
}
// OIDCOAuthStart 启动通用 OIDC OAuth 登录流程。
// GET /api/v1/auth/oauth/oidc/start?redirect=/dashboard
func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
cfg, err := h.getOIDCOAuthConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
secureCookie := isRequestHTTPS(c)
oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie)
oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie)
codeChallenge := ""
if cfg.UsePKCE {
verifier, genErr := oauth.GenerateCodeVerifier()
if genErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
}
nonce := ""
if cfg.ValidateIDToken {
nonce, err = oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
return
}
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
return
}
authURL, err := buildOIDCAuthorizeURL(cfg, state, nonce, codeChallenge, redirectURI)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// OIDCOAuthCallback 处理 OIDC 回调:校验 id_token、创建/登录用户并重定向到前端。
// GET /api/v1/auth/oauth/oidc/callback?code=...&state=...
func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
cfg, cfgErr := h.getOIDCOAuthConfig(c.Request.Context())
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = oidcOAuthDefaultFrontendCB
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, oidcOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
expectedNonce := ""
if cfg.ValidateIDToken {
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
if expectedNonce == "" {
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
return
}
tokenResp, err := oidcExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
if err != nil {
description := ""
var exchangeErr *oidcTokenExchangeError
if errors.As(err, &exchangeErr) && exchangeErr != nil {
log.Printf(
"[OIDC OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
exchangeErr.StatusCode,
exchangeErr.ProviderError,
exchangeErr.ProviderDescription,
truncateLogValue(exchangeErr.Body, 2048),
)
description = exchangeErr.Error()
} else {
log.Printf("[OIDC OAuth] token exchange failed: %v", err)
description = err.Error()
}
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
return
}
if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
if err != nil {
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
return
}
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[OIDC OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
subject := strings.TrimSpace(idClaims.Subject)
if subject == "" {
subject = strings.TrimSpace(userInfoClaims.Subject)
}
if subject == "" {
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
return
}
issuer := strings.TrimSpace(idClaims.Issuer)
if issuer == "" {
issuer = strings.TrimSpace(cfg.IssuerURL)
}
if issuer == "" {
redirectOAuthError(c, frontendCallback, "missing_issuer", "missing issuer claim", "")
return
}
emailVerified := userInfoClaims.EmailVerified
if emailVerified == nil {
emailVerified = idClaims.EmailVerified
}
if cfg.RequireEmailVerified {
if emailVerified == nil || !*emailVerified {
redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
return
}
}
identityKey := oidcIdentityKey(issuer, subject)
email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
idClaims.PreferredUsername,
idClaims.Name,
oidcFallbackUsername(subject),
)
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
if tokenErr != nil {
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
return
}
fragment := url.Values{}
fragment.Set("error", "invitation_required")
fragment.Set("pending_oauth_token", pendingToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
return
}
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
type completeOIDCOAuthRequest struct {
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
// the invitation code and creating the user account.
// POST /api/v1/auth/oauth/oidc/complete-registration
func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
var req completeOIDCOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
return
}
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) getOIDCOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetOIDCConnectOAuthConfig(ctx)
}
if h == nil || h.cfg == nil {
return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
if !h.cfg.OIDC.Enabled {
return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
return h.cfg.OIDC, nil
}
func oidcExchangeCode(
ctx context.Context,
cfg config.OIDCConnectConfig,
code string,
redirectURI string,
codeVerifier string,
) (*oidcTokenResponse, error) {
client := req.C().SetTimeout(30 * time.Second)
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if cfg.UsePKCE {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json")
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
case "", "client_secret_post":
form.Set("client_secret", cfg.ClientSecret)
case "client_secret_basic":
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
case "none":
default:
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
}
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
if err != nil {
return nil, fmt.Errorf("request token: %w", err)
}
body := strings.TrimSpace(resp.String())
if !resp.IsSuccessState() {
providerErr, providerDesc := parseOAuthProviderError(body)
return nil, &oidcTokenExchangeError{
StatusCode: resp.StatusCode,
ProviderError: providerErr,
ProviderDescription: providerDesc,
Body: body,
}
}
tokenResp, ok := oidcParseTokenResponse(body)
if !ok {
return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body}
}
if strings.TrimSpace(tokenResp.TokenType) == "" {
tokenResp.TokenType = "Bearer"
}
if strings.TrimSpace(tokenResp.AccessToken) == "" && strings.TrimSpace(tokenResp.IDToken) == "" {
return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body}
}
return tokenResp, nil
}
func oidcParseTokenResponse(body string) (*oidcTokenResponse, bool) {
body = strings.TrimSpace(body)
if body == "" {
return nil, false
}
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
idToken := strings.TrimSpace(getGJSON(body, "id_token"))
if accessToken != "" || idToken != "" {
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
scope := strings.TrimSpace(getGJSON(body, "scope"))
expiresIn := gjson.Get(body, "expires_in").Int()
return &oidcTokenResponse{
AccessToken: accessToken,
TokenType: tokenType,
ExpiresIn: expiresIn,
RefreshToken: refreshToken,
Scope: scope,
IDToken: idToken,
}, true
}
values, err := url.ParseQuery(body)
if err != nil {
return nil, false
}
accessToken = strings.TrimSpace(values.Get("access_token"))
idToken = strings.TrimSpace(values.Get("id_token"))
if accessToken == "" && idToken == "" {
return nil, false
}
expiresIn := int64(0)
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
if v, parseErr := strconv.ParseInt(raw, 10, 64); parseErr == nil {
expiresIn = v
}
}
return &oidcTokenResponse{
AccessToken: accessToken,
TokenType: strings.TrimSpace(values.Get("token_type")),
ExpiresIn: expiresIn,
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
Scope: strings.TrimSpace(values.Get("scope")),
IDToken: idToken,
}, true
}
func oidcFetchUserInfo(
ctx context.Context,
cfg config.OIDCConnectConfig,
token *oidcTokenResponse,
) (*oidcUserInfoClaims, error) {
if strings.TrimSpace(cfg.UserInfoURL) == "" {
return &oidcUserInfoClaims{}, nil
}
if token == nil || strings.TrimSpace(token.AccessToken) == "" {
return nil, errors.New("missing access_token for userinfo request")
}
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
return nil, fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
return nil, fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return oidcParseUserInfo(resp.String(), cfg), nil
}
func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoClaims {
claims := &oidcUserInfoClaims{}
claims.Email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
getGJSON(body, "user.email"),
getGJSON(body, "data.email"),
getGJSON(body, "attributes.email"),
)
claims.Username = firstNonEmpty(
getGJSON(body, cfg.UserInfoUsernamePath),
getGJSON(body, "preferred_username"),
getGJSON(body, "username"),
getGJSON(body, "name"),
getGJSON(body, "user.username"),
getGJSON(body, "user.name"),
)
claims.Subject = firstNonEmpty(
getGJSON(body, cfg.UserInfoIDPath),
getGJSON(body, "sub"),
getGJSON(body, "id"),
getGJSON(body, "user_id"),
getGJSON(body, "uid"),
getGJSON(body, "user.id"),
)
if verified, ok := getGJSONBool(body, "email_verified"); ok {
claims.EmailVerified = &verified
}
claims.Email = strings.TrimSpace(claims.Email)
claims.Username = strings.TrimSpace(claims.Username)
claims.Subject = strings.TrimSpace(claims.Subject)
return claims
}
func getGJSONBool(body string, path string) (bool, bool) {
path = strings.TrimSpace(path)
if path == "" {
return false, false
}
res := gjson.Get(body, path)
if !res.Exists() {
return false, false
}
return res.Bool(), true
}
func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChallenge, redirectURI string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", redirectURI)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if strings.TrimSpace(nonce) != "" {
q.Set("nonce", nonce)
}
if cfg.UsePKCE {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func oidcParseAndValidateIDToken(ctx context.Context, cfg config.OIDCConnectConfig, idToken string, expectedNonce string) (*oidcIDTokenClaims, error) {
idToken = strings.TrimSpace(idToken)
if idToken == "" {
return nil, errors.New("missing id_token")
}
allowed := oidcAllowedSigningAlgs(cfg.AllowedSigningAlgs)
if len(allowed) == 0 {
return nil, errors.New("empty allowed signing algorithms")
}
jwks, err := oidcFetchJWKSet(ctx, cfg.JWKSURL)
if err != nil {
return nil, err
}
leeway := time.Duration(cfg.ClockSkewSeconds) * time.Second
claims := &oidcIDTokenClaims{}
parsed, err := jwt.ParseWithClaims(
idToken,
claims,
func(token *jwt.Token) (any, error) {
alg := strings.TrimSpace(token.Method.Alg())
if !containsString(allowed, alg) {
return nil, fmt.Errorf("unexpected signing algorithm: %s", alg)
}
kid, _ := token.Header["kid"].(string)
return oidcFindPublicKey(jwks, strings.TrimSpace(kid), alg)
},
jwt.WithValidMethods(allowed),
jwt.WithAudience(cfg.ClientID),
jwt.WithIssuer(cfg.IssuerURL),
jwt.WithLeeway(leeway),
)
if err != nil {
return nil, err
}
if !parsed.Valid {
return nil, errors.New("id_token invalid")
}
if strings.TrimSpace(claims.Subject) == "" {
return nil, errors.New("id_token missing sub")
}
if expectedNonce != "" && strings.TrimSpace(claims.Nonce) != strings.TrimSpace(expectedNonce) {
return nil, errors.New("id_token nonce mismatch")
}
if len(claims.Audience) > 1 {
if strings.TrimSpace(claims.Azp) == "" || strings.TrimSpace(claims.Azp) != strings.TrimSpace(cfg.ClientID) {
return nil, errors.New("id_token azp mismatch")
}
}
return claims, nil
}
func oidcAllowedSigningAlgs(raw string) []string {
if strings.TrimSpace(raw) == "" {
return []string{"RS256", "ES256", "PS256"}
}
seen := make(map[string]struct{})
out := make([]string, 0, 4)
for _, part := range strings.Split(raw, ",") {
alg := strings.ToUpper(strings.TrimSpace(part))
if alg == "" {
continue
}
if _, ok := seen[alg]; ok {
continue
}
seen[alg] = struct{}{}
out = append(out, alg)
}
return out
}
func oidcFetchJWKSet(ctx context.Context, jwksURL string) (*oidcJWKSet, error) {
jwksURL = strings.TrimSpace(jwksURL)
if jwksURL == "" {
return nil, errors.New("missing jwks_url")
}
resp, err := req.C().
SetTimeout(30*time.Second).
R().
SetContext(ctx).
SetHeader("Accept", "application/json").
Get(jwksURL)
if err != nil {
return nil, fmt.Errorf("request jwks: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("jwks status=%d", resp.StatusCode)
}
set := &oidcJWKSet{}
if err := json.Unmarshal(resp.Bytes(), set); err != nil {
return nil, fmt.Errorf("parse jwks: %w", err)
}
if len(set.Keys) == 0 {
return nil, errors.New("jwks empty keys")
}
return set, nil
}
func oidcFindPublicKey(set *oidcJWKSet, kid, alg string) (any, error) {
if set == nil {
return nil, errors.New("jwks not loaded")
}
alg = strings.ToUpper(strings.TrimSpace(alg))
kid = strings.TrimSpace(kid)
var lastErr error
for i := range set.Keys {
k := set.Keys[i]
if strings.TrimSpace(k.Use) != "" && !strings.EqualFold(strings.TrimSpace(k.Use), "sig") {
continue
}
if kid != "" && strings.TrimSpace(k.Kid) != kid {
continue
}
if strings.TrimSpace(k.Alg) != "" && !strings.EqualFold(strings.TrimSpace(k.Alg), alg) {
continue
}
pk, err := k.publicKey()
if err != nil {
lastErr = err
continue
}
if pk != nil {
return pk, nil
}
}
if lastErr != nil {
return nil, lastErr
}
if kid != "" {
return nil, fmt.Errorf("jwk not found for kid=%s", kid)
}
return nil, errors.New("jwk not found")
}
func (k oidcJWK) publicKey() (any, error) {
switch strings.ToUpper(strings.TrimSpace(k.Kty)) {
case "RSA":
n, err := decodeBase64URLBigInt(k.N)
if err != nil {
return nil, fmt.Errorf("decode rsa n: %w", err)
}
eBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(k.E))
if err != nil {
return nil, fmt.Errorf("decode rsa e: %w", err)
}
if len(eBytes) == 0 {
return nil, errors.New("empty rsa e")
}
e := 0
for _, b := range eBytes {
e = (e << 8) | int(b)
}
if e <= 0 {
return nil, errors.New("invalid rsa exponent")
}
if n.Sign() <= 0 {
return nil, errors.New("invalid rsa modulus")
}
return &rsa.PublicKey{N: n, E: e}, nil
case "EC":
var curve elliptic.Curve
switch strings.TrimSpace(k.Crv) {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported ec curve: %s", k.Crv)
}
x, err := decodeBase64URLBigInt(k.X)
if err != nil {
return nil, fmt.Errorf("decode ec x: %w", err)
}
y, err := decodeBase64URLBigInt(k.Y)
if err != nil {
return nil, fmt.Errorf("decode ec y: %w", err)
}
if !curve.IsOnCurve(x, y) {
return nil, errors.New("ec point is not on curve")
}
return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil
default:
return nil, fmt.Errorf("unsupported jwk kty: %s", k.Kty)
}
}
func decodeBase64URLBigInt(raw string) (*big.Int, error) {
buf, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(raw))
if err != nil {
return nil, err
}
if len(buf) == 0 {
return nil, errors.New("empty value")
}
return new(big.Int).SetBytes(buf), nil
}
func containsString(values []string, target string) bool {
target = strings.TrimSpace(target)
for _, v := range values {
if strings.EqualFold(strings.TrimSpace(v), target) {
return true
}
}
return false
}
func oidcIdentityKey(issuer, subject string) string {
issuer = strings.TrimSpace(strings.ToLower(issuer))
subject = strings.TrimSpace(subject)
return issuer + "\x1f" + subject
}
func oidcSyntheticEmailFromIdentityKey(identityKey string) string {
identityKey = strings.TrimSpace(identityKey)
if identityKey == "" {
return ""
}
sum := sha256.Sum256([]byte(identityKey))
return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain
}
func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string {
email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail))
if email != "" {
return email
}
return oidcSyntheticEmailFromIdentityKey(identityKey)
}
func oidcFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return "oidc_user"
}
sum := sha256.Sum256([]byte(subject))
return "oidc_" + hex.EncodeToString(sum[:])[:12]
}
func oidcSetCookie(c *gin.Context, name, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: oidcOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func oidcClearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: oidcOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
package handler
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"math/big"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) {
k1 := oidcIdentityKey("https://issuer.example.com", "subject-a")
k2 := oidcIdentityKey("https://issuer.example.com", "subject-b")
e1 := oidcSyntheticEmailFromIdentityKey(k1)
e1Again := oidcSyntheticEmailFromIdentityKey(k1)
e2 := oidcSyntheticEmailFromIdentityKey(k2)
require.Equal(t, e1, e1Again)
require.NotEqual(t, e1, e2)
require.Contains(t, e1, "@oidc-connect.invalid")
}
func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) {
identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a")
email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey)
require.Equal(t, "user@example.com", email)
email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey)
require.Equal(t, "idtoken@example.com", email)
email = oidcSelectLoginEmail("", "", identityKey)
require.Contains(t, email, "@oidc-connect.invalid")
require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email)
}
func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) {
cfg := config.OIDCConnectConfig{
AuthorizeURL: "https://issuer.example.com/auth",
ClientID: "cid",
Scopes: "openid email profile",
UsePKCE: true,
}
u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback")
require.NoError(t, err)
require.Contains(t, u, "nonce=nonce123")
require.Contains(t, u, "code_challenge=challenge123")
require.Contains(t, u, "code_challenge_method=S256")
require.Contains(t, u, "scope=openid+email+profile")
}
func TestOIDCParseAndValidateIDToken(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
kid := "kid-1"
jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &priv.PublicKey)}}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, json.NewEncoder(w).Encode(jwks))
}))
defer srv.Close()
now := time.Now()
claims := oidcIDTokenClaims{
Nonce: "nonce-ok",
Azp: "client-1",
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "https://issuer.example.com",
Subject: "subject-1",
Audience: jwt.ClaimStrings{"client-1", "another-aud"},
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)),
ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tok.Header["kid"] = kid
signed, err := tok.SignedString(priv)
require.NoError(t, err)
cfg := config.OIDCConnectConfig{
ClientID: "client-1",
IssuerURL: "https://issuer.example.com",
JWKSURL: srv.URL,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
}
parsed, err := oidcParseAndValidateIDToken(context.Background(), cfg, signed, "nonce-ok")
require.NoError(t, err)
require.Equal(t, "subject-1", parsed.Subject)
require.Equal(t, "https://issuer.example.com", parsed.Issuer)
_, err = oidcParseAndValidateIDToken(context.Background(), cfg, signed, "bad-nonce")
require.Error(t, err)
}
func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes())
e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes())
return oidcJWK{
Kty: "RSA",
Kid: kid,
Use: "sig",
Alg: "RS256",
N: n,
E: e,
}
}
...@@ -133,16 +133,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { ...@@ -133,16 +133,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil return nil
} }
out := &AdminGroup{ out := &AdminGroup{
Group: groupFromServiceBase(g), Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject, MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel, DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
AccountCount: g.AccountCount, SupportedModelScopes: g.SupportedModelScopes,
ActiveAccountCount: g.ActiveAccountCount, AccountCount: g.AccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount, ActiveAccountCount: g.ActiveAccountCount,
SortOrder: g.SortOrder, RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
} }
if len(g.AccountGroups) > 0 { if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
......
...@@ -51,6 +51,29 @@ type SystemSettings struct { ...@@ -51,6 +51,29 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
OIDCConnectClientSecretConfigured bool `json:"oidc_connect_client_secret_configured"`
OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"`
OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"`
OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"`
OIDCConnectTokenURL string `json:"oidc_connect_token_url"`
OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"`
OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"`
OIDCConnectScopes string `json:"oidc_connect_scopes"`
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"`
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
...@@ -132,6 +155,9 @@ type PublicSettings struct { ...@@ -132,6 +155,9 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version"` Version string `json:"version"`
} }
......
package dto package dto
import "time" import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
type User struct { type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
...@@ -112,7 +116,8 @@ type AdminGroup struct { ...@@ -112,7 +116,8 @@ type AdminGroup struct {
MCPXMLInject bool `json:"mcp_xml_inject"` MCPXMLInject bool `json:"mcp_xml_inject"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
DefaultMappedModel string `json:"default_mapped_model"` DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
......
...@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode ...@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
return strings.TrimSpace(apiKey.Group.DefaultMappedModel) return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
} }
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler( func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
...@@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
} }
reqModel := modelResult.String() reqModel := modelResult.String()
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel) routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool() reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
...@@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError var lastFailoverErr *service.UpstreamFailoverError
effectiveMappedModel := preferredMappedModel
for { for {
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代 currentRoutingModel := routingModel
c.Set("openai_messages_fallback_model", "") if effectiveMappedModel != "" {
currentRoutingModel = effectiveMappedModel
}
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(), c.Request.Context(),
apiKey.GroupID, apiKey.GroupID,
"", // no previous_response_id "", // no previous_response_id
sessionHash, sessionHash,
routingModel, currentRoutingModel,
failedAccountIDs, failedAccountIDs,
service.OpenAIUpstreamTransportAny, service.OpenAIUpstreamTransportAny,
) )
...@@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
zap.Error(err), zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)), zap.Int("excluded_account_count", len(failedAccountIDs)),
) )
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != routingModel {
reqLog.Info("openai_messages.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_messages_fallback_model", defaultModel)
}
}
if err != nil { if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
...@@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
// 应用渠道模型映射到请求体 // 应用渠道模型映射到请求体
forwardBody := body forwardBody := body
if channelMappingMsg.Mapped { if channelMappingMsg.Mapped {
......
...@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { ...@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
}) })
t.Run("uses_group_default_on_normal_path", func(t *testing.T) { t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) {
apiKey := &service.APIKey{ apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
} }
...@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { ...@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
}) })
} }
func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) {
t.Run("exact_claude_model_override_wins", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.2",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.4-mini-high",
},
},
},
}
require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
})
t.Run("uses_family_default_when_no_override", func(t *testing.T) {
apiKey := &service.APIKey{Group: &service.Group{}}
require.Equal(t, "gpt-5.4", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-opus-4-6"))
require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-haiku-4-5-20251001"))
})
t.Run("returns_empty_for_non_claude_or_missing_group", func(t *testing.T) {
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(nil, "claude-sonnet-4-5-20250929"))
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{}, "claude-sonnet-4-5-20250929"))
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{Group: &service.Group{}}, "gpt-5.4"))
})
t.Run("does_not_fall_back_to_group_default_mapped_model", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{
DefaultMappedModel: "gpt-5.4",
},
}
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(apiKey, "gpt-5.4"))
require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
})
}
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment