Unverified Commit bbc79796 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1529 from IanShaw027/feat/group-messages-dispatch-redo

feat: 为openai分组增加messages调度模型映射并支持instructions模板注入
parents 760cc7d6 7d008bd5
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// Group is the model entity for the Group schema. // Group is the model entity for the Group schema.
...@@ -76,6 +77,8 @@ type Group struct { ...@@ -76,6 +77,8 @@ type Group struct {
RequirePrivacySet bool `json:"require_privacy_set,omitempty"` RequirePrivacySet bool `json:"require_privacy_set,omitempty"`
// 默认映射模型 ID,当账号级映射找不到时使用此值 // 默认映射模型 ID,当账号级映射找不到时使用此值
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set. // The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"` Edges GroupEdges `json:"edges"`
...@@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { ...@@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes: case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
values[i] = new([]byte) values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
...@@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error { ...@@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.DefaultMappedModel = value.String _m.DefaultMappedModel = value.String
} }
case group.FieldMessagesDispatchModelConfig:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field messages_dispatch_model_config", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.MessagesDispatchModelConfig); err != nil {
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
}
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
...@@ -585,6 +596,9 @@ func (_m *Group) String() string { ...@@ -585,6 +596,9 @@ func (_m *Group) String() string {
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("default_mapped_model=") builder.WriteString("default_mapped_model=")
builder.WriteString(_m.DefaultMappedModel) builder.WriteString(_m.DefaultMappedModel)
builder.WriteString(", ")
builder.WriteString("messages_dispatch_model_config=")
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
const ( const (
...@@ -73,6 +74,8 @@ const ( ...@@ -73,6 +74,8 @@ const (
FieldRequirePrivacySet = "require_privacy_set" FieldRequirePrivacySet = "require_privacy_set"
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
FieldDefaultMappedModel = "default_mapped_model" FieldDefaultMappedModel = "default_mapped_model"
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys" EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
...@@ -177,6 +180,7 @@ var Columns = []string{ ...@@ -177,6 +180,7 @@ var Columns = []string{
FieldRequireOauthOnly, FieldRequireOauthOnly,
FieldRequirePrivacySet, FieldRequirePrivacySet,
FieldDefaultMappedModel, FieldDefaultMappedModel,
FieldMessagesDispatchModelConfig,
} }
var ( var (
...@@ -252,6 +256,8 @@ var ( ...@@ -252,6 +256,8 @@ var (
DefaultDefaultMappedModel string DefaultDefaultMappedModel string
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
DefaultMappedModelValidator func(string) error DefaultMappedModelValidator func(string) error
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
) )
// OrderOption defines the ordering options for the Group queries. // OrderOption defines the ordering options for the Group queries.
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// GroupCreate is the builder for creating a Group entity. // GroupCreate is the builder for creating a Group entity.
...@@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate { ...@@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
return _c return _c
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (_c *GroupCreate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupCreate {
_c.mutation.SetMessagesDispatchModelConfig(v)
return _c
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupCreate {
if v != nil {
_c.SetMessagesDispatchModelConfig(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...) _c.mutation.AddAPIKeyIDs(ids...)
...@@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error { ...@@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultMappedModel v := group.DefaultDefaultMappedModel
_c.mutation.SetDefaultMappedModel(v) _c.mutation.SetDefaultMappedModel(v)
} }
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
v := group.DefaultMessagesDispatchModelConfig
_c.mutation.SetMessagesDispatchModelConfig(v)
}
return nil return nil
} }
...@@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error { ...@@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
} }
} }
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
}
return nil return nil
} }
...@@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { ...@@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
_node.DefaultMappedModel = value _node.DefaultMappedModel = value
} }
if value, ok := _c.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
_node.MessagesDispatchModelConfig = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert { ...@@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
return u return u
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (u *GroupUpsert) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsert {
u.Set(group.FieldMessagesDispatchModelConfig, v)
return u
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
u.SetExcluded(group.FieldMessagesDispatchModelConfig)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
...@@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne { ...@@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
}) })
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (u *GroupUpsertOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetMessagesDispatchModelConfig(v)
})
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateMessagesDispatchModelConfig()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error { func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
...@@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk { ...@@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
}) })
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (u *GroupUpsertBulk) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetMessagesDispatchModelConfig(v)
})
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateMessagesDispatchModelConfig()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error { func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// GroupUpdate is the builder for updating Group entities. // GroupUpdate is the builder for updating Group entities.
...@@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate { ...@@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
return _u return _u
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (_u *GroupUpdate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate {
_u.mutation.SetMessagesDispatchModelConfig(v)
return _u
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate {
if v != nil {
_u.SetMessagesDispatchModelConfig(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.DefaultMappedModel(); ok { if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
} }
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO ...@@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
return _u return _u
} }
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (_u *GroupUpdateOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne {
_u.mutation.SetMessagesDispatchModelConfig(v)
return _u
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne {
if v != nil {
_u.SetMessagesDispatchModelConfig(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) ...@@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.DefaultMappedModel(); ok { if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
} }
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
......
...@@ -407,6 +407,7 @@ var ( ...@@ -407,6 +407,7 @@ var (
{Name: "require_oauth_only", Type: field.TypeBool, Default: false}, {Name: "require_oauth_only", Type: field.TypeBool, Default: false},
{Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "require_privacy_set", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
} }
// GroupsTable holds the schema information for the "groups" table. // GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{ GroupsTable = &schema.Table{
......
...@@ -8246,6 +8246,7 @@ type GroupMutation struct { ...@@ -8246,6 +8246,7 @@ type GroupMutation struct {
require_oauth_only *bool require_oauth_only *bool
require_privacy_set *bool require_privacy_set *bool
default_mapped_model *string default_mapped_model *string
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
clearedFields map[string]struct{} clearedFields map[string]struct{}
api_keys map[int64]struct{} api_keys map[int64]struct{}
removedapi_keys map[int64]struct{} removedapi_keys map[int64]struct{}
...@@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() { ...@@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
m.default_mapped_model = nil m.default_mapped_model = nil
} }
   
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) {
m.messages_dispatch_model_config = &damdmc
}
// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation.
func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) {
v := m.messages_dispatch_model_config
if v == nil {
return
}
return *v, true
}
// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err)
}
return oldValue.MessagesDispatchModelConfig, nil
}
// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field.
func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
m.messages_dispatch_model_config = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil { if m.api_keys == nil {
...@@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string { ...@@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *GroupMutation) Fields() []string { func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 29) fields := make([]string, 0, 30)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt) fields = append(fields, group.FieldCreatedAt)
} }
...@@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string { ...@@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string {
if m.default_mapped_model != nil { if m.default_mapped_model != nil {
fields = append(fields, group.FieldDefaultMappedModel) fields = append(fields, group.FieldDefaultMappedModel)
} }
if m.messages_dispatch_model_config != nil {
fields = append(fields, group.FieldMessagesDispatchModelConfig)
}
return fields return fields
} }
   
...@@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { ...@@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.RequirePrivacySet() return m.RequirePrivacySet()
case group.FieldDefaultMappedModel: case group.FieldDefaultMappedModel:
return m.DefaultMappedModel() return m.DefaultMappedModel()
case group.FieldMessagesDispatchModelConfig:
return m.MessagesDispatchModelConfig()
} }
return nil, false return nil, false
} }
...@@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e ...@@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldRequirePrivacySet(ctx) return m.OldRequirePrivacySet(ctx)
case group.FieldDefaultMappedModel: case group.FieldDefaultMappedModel:
return m.OldDefaultMappedModel(ctx) return m.OldDefaultMappedModel(ctx)
case group.FieldMessagesDispatchModelConfig:
return m.OldMessagesDispatchModelConfig(ctx)
} }
return nil, fmt.Errorf("unknown Group field %s", name) return nil, fmt.Errorf("unknown Group field %s", name)
} }
...@@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { ...@@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
} }
m.SetDefaultMappedModel(v) m.SetDefaultMappedModel(v)
return nil return nil
case group.FieldMessagesDispatchModelConfig:
v, ok := value.(domain.OpenAIMessagesDispatchModelConfig)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetMessagesDispatchModelConfig(v)
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
...@@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error { ...@@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultMappedModel: case group.FieldDefaultMappedModel:
m.ResetDefaultMappedModel() m.ResetDefaultMappedModel()
return nil return nil
case group.FieldMessagesDispatchModelConfig:
m.ResetMessagesDispatchModelConfig()
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
......
...@@ -28,6 +28,7 @@ import ( ...@@ -28,6 +28,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue" "github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
// The init function reads all schema descriptors with runtime code // The init function reads all schema descriptors with runtime code
...@@ -468,6 +469,10 @@ func init() { ...@@ -468,6 +469,10 @@ func init() {
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0 _ = idempotencyrecordMixinFields0
......
...@@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field { ...@@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field {
MaxLen(100). MaxLen(100).
Default(""). Default("").
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"), Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
field.JSON("messages_dispatch_model_config", domain.OpenAIMessagesDispatchModelConfig{}).
Default(domain.OpenAIMessagesDispatchModelConfig{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
} }
} }
......
...@@ -347,6 +347,12 @@ type GatewayConfig struct { ...@@ -347,6 +347,12 @@ type GatewayConfig struct {
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"` ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
// 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
// ForcedCodexInstructionsTemplate: 启动时从模板文件读取并缓存的模板内容。
// 该字段不直接参与配置反序列化,仅用于请求热路径避免重复读盘。
ForcedCodexInstructionsTemplate string `mapstructure:"-"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
...@@ -1029,6 +1035,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) { ...@@ -1029,6 +1035,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment) cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
cfg.Gateway.ForcedCodexInstructionsTemplateFile = strings.TrimSpace(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
if cfg.Gateway.ForcedCodexInstructionsTemplateFile != "" {
content, err := os.ReadFile(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
if err != nil {
return nil, fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err)
}
cfg.Gateway.ForcedCodexInstructionsTemplate = string(content)
}
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0)时回退旧键;新键优先。 // 新键未配置(<=0)时回退旧键;新键优先。
......
package config package config
import ( import (
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"time" "time"
...@@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { ...@@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
} }
} }
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
tempDir := t.TempDir()
templatePath := filepath.Join(tempDir, "codex-instructions.md.tmpl")
configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir)
cfg, err := Load()
require.NoError(t, err)
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
}
func TestLoadDefaultSecurityToggles(t *testing.T) { func TestLoadDefaultSecurityToggles(t *testing.T) {
resetViperWithJWTSecret(t) resetViperWithJWTSecret(t)
......
package domain
// OpenAIMessagesDispatchModelConfig controls how Anthropic /v1/messages
// requests are mapped onto OpenAI/Codex models.
type OpenAIMessagesDispatchModelConfig struct {
OpusMappedModel string `json:"opus_mapped_model,omitempty"`
SonnetMappedModel string `json:"sonnet_mapped_model,omitempty"`
HaikuMappedModel string `json:"haiku_mapped_model,omitempty"`
ExactModelMappings map[string]string `json:"exact_model_mappings,omitempty"`
}
...@@ -105,10 +105,11 @@ type CreateGroupRequest struct { ...@@ -105,10 +105,11 @@ type CreateGroupRequest struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
RequireOAuthOnly bool `json:"require_oauth_only"` RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"` RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"` DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 从指定分组复制账号(创建后自动绑定) // 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -139,10 +140,11 @@ type UpdateGroupRequest struct { ...@@ -139,10 +140,11 @@ type UpdateGroupRequest struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"` SupportedModelScopes *[]string `json:"supported_model_scopes"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
RequireOAuthOnly *bool `json:"require_oauth_only"` RequireOAuthOnly *bool `json:"require_oauth_only"`
RequirePrivacySet *bool `json:"require_privacy_set"` RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"` DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -257,6 +259,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -257,6 +259,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequireOAuthOnly: req.RequireOAuthOnly, RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet, RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel, DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
...@@ -307,6 +310,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -307,6 +310,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequireOAuthOnly: req.RequireOAuthOnly, RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet, RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel, DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
......
...@@ -133,16 +133,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { ...@@ -133,16 +133,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil return nil
} }
out := &AdminGroup{ out := &AdminGroup{
Group: groupFromServiceBase(g), Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject, MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel, DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
AccountCount: g.AccountCount, SupportedModelScopes: g.SupportedModelScopes,
ActiveAccountCount: g.ActiveAccountCount, AccountCount: g.AccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount, ActiveAccountCount: g.ActiveAccountCount,
SortOrder: g.SortOrder, RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
} }
if len(g.AccountGroups) > 0 { if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
......
package dto package dto
import "time" import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
type User struct { type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
...@@ -112,7 +116,8 @@ type AdminGroup struct { ...@@ -112,7 +116,8 @@ type AdminGroup struct {
MCPXMLInject bool `json:"mcp_xml_inject"` MCPXMLInject bool `json:"mcp_xml_inject"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
DefaultMappedModel string `json:"default_mapped_model"` DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
......
...@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode ...@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
return strings.TrimSpace(apiKey.Group.DefaultMappedModel) return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
} }
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler( func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
...@@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
} }
reqModel := modelResult.String() reqModel := modelResult.String()
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel) routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool() reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
...@@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError var lastFailoverErr *service.UpstreamFailoverError
effectiveMappedModel := preferredMappedModel
for { for {
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代 currentRoutingModel := routingModel
c.Set("openai_messages_fallback_model", "") if effectiveMappedModel != "" {
currentRoutingModel = effectiveMappedModel
}
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(), c.Request.Context(),
apiKey.GroupID, apiKey.GroupID,
"", // no previous_response_id "", // no previous_response_id
sessionHash, sessionHash,
routingModel, currentRoutingModel,
failedAccountIDs, failedAccountIDs,
service.OpenAIUpstreamTransportAny, service.OpenAIUpstreamTransportAny,
) )
...@@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
zap.Error(err), zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)), zap.Int("excluded_account_count", len(failedAccountIDs)),
) )
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != routingModel {
reqLog.Info("openai_messages.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_messages_fallback_model", defaultModel)
}
}
if err != nil { if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
...@@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
// 应用渠道模型映射到请求体 // 应用渠道模型映射到请求体
forwardBody := body forwardBody := body
if channelMappingMsg.Mapped { if channelMappingMsg.Mapped {
......
...@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { ...@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
}) })
t.Run("uses_group_default_on_normal_path", func(t *testing.T) { t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) {
apiKey := &service.APIKey{ apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
} }
...@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { ...@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
}) })
} }
func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) {
t.Run("exact_claude_model_override_wins", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.2",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.4-mini-high",
},
},
},
}
require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
})
t.Run("uses_family_default_when_no_override", func(t *testing.T) {
apiKey := &service.APIKey{Group: &service.Group{}}
require.Equal(t, "gpt-5.4", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-opus-4-6"))
require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-haiku-4-5-20251001"))
})
t.Run("returns_empty_for_non_claude_or_missing_group", func(t *testing.T) {
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(nil, "claude-sonnet-4-5-20250929"))
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{}, "claude-sonnet-4-5-20250929"))
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{Group: &service.Group{}}, "gpt-5.4"))
})
t.Run("does_not_fall_back_to_group_default_mapped_model", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{
DefaultMappedModel: "gpt-5.4",
},
}
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(apiKey, "gpt-5.4"))
require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
})
}
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
...@@ -28,7 +28,7 @@ type AnthropicRequest struct { ...@@ -28,7 +28,7 @@ type AnthropicRequest struct {
// AnthropicOutputConfig controls output generation parameters. // AnthropicOutputConfig controls output generation parameters.
type AnthropicOutputConfig struct { type AnthropicOutputConfig struct {
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" | "max"
} }
// AnthropicThinking configures extended thinking in the Anthropic API. // AnthropicThinking configures extended thinking in the Anthropic API.
...@@ -167,7 +167,7 @@ type ResponsesRequest struct { ...@@ -167,7 +167,7 @@ type ResponsesRequest struct {
// ResponsesReasoning configures reasoning effort in the Responses API. // ResponsesReasoning configures reasoning effort in the Responses API.
type ResponsesReasoning struct { type ResponsesReasoning struct {
Effort string `json:"effort"` // "low" | "medium" | "high" Effort string `json:"effort"` // "low" | "medium" | "high" | "xhigh"
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
} }
...@@ -345,7 +345,7 @@ type ChatCompletionsRequest struct { ...@@ -345,7 +345,7 @@ type ChatCompletionsRequest struct {
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"` Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" | "xhigh"
ServiceTier string `json:"service_tier,omitempty"` ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string Stop json.RawMessage `json:"stop,omitempty"` // string or []string
......
...@@ -58,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -58,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel) SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 设置模型路由配置 // 设置模型路由配置
if groupIn.ModelRouting != nil { if groupIn.ModelRouting != nil {
...@@ -124,7 +125,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -124,7 +125,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel) SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if groupIn.DailyLimitUSD != nil { if groupIn.DailyLimitUSD != nil {
......
...@@ -152,10 +152,11 @@ type CreateGroupInput struct { ...@@ -152,10 +152,11 @@ type CreateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string SupportedModelScopes []string
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool AllowMessagesDispatch bool
DefaultMappedModel string DefaultMappedModel string
RequireOAuthOnly bool RequireOAuthOnly bool
RequirePrivacySet bool RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(创建分组后在同一事务内绑定) // 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -186,10 +187,11 @@ type UpdateGroupInput struct { ...@@ -186,10 +187,11 @@ type UpdateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string SupportedModelScopes *[]string
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool AllowMessagesDispatch *bool
DefaultMappedModel *string DefaultMappedModel *string
RequireOAuthOnly *bool RequireOAuthOnly *bool
RequirePrivacySet *bool RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequireOAuthOnly: input.RequireOAuthOnly, RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet, RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel, DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
} }
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
} }
...@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.DefaultMappedModel != nil { if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel group.DefaultMappedModel = *input.DefaultMappedModel
} }
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
......
...@@ -10,6 +10,11 @@ import ( ...@@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func ptrString[T ~string](v T) *string {
s := string(v)
return &s
}
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub // groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct { type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数 created *Group // 记录 Create 调用的参数
...@@ -245,6 +250,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -245,6 +250,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "dispatch-group",
Description: "dispatch config",
Platform: PlatformOpenAI,
RateMultiplier: 1.0,
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: " gpt-5.3-codex ",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
},
}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: " gpt-5.4-medium ",
ExactModelMappings: map[string]string{
" claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.4",
ExactModelMappings: map[string]string{
"claude-haiku-4-5-20251001": "gpt-5.4-mini",
},
}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "anthropic-group",
Description: "non-openai",
Platform: PlatformAnthropic,
RateMultiplier: 1.0,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.False(t, repo.created.AllowMessagesDispatch)
require.Empty(t, repo.created.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-openai-group",
Platform: PlatformOpenAI,
Status: StatusActive,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.3-codex",
},
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Platform: PlatformAnthropic,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, PlatformAnthropic, repo.updated.Platform)
require.False(t, repo.updated.AllowMessagesDispatch)
require.Empty(t, repo.updated.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_ListGroups_WithSearch(t *testing.T) { func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试: // 测试:
// 1. search 参数正常传递到 repository 层 // 1. search 参数正常传递到 repository 层
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment