"backend/internal/handler/vscode:/vscode.git/clone" did not exist on "e97c376681b32ea526e04686c8a5c3e8904298d8"
Unverified Commit 06093d4f authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #311 from longgexx/main

添加分组级别模型路由配置功能(Anthropic平台)
parents 452fa53c 577ee161
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
package ent package ent
import ( import (
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"time" "time"
...@@ -55,6 +56,10 @@ type Group struct { ...@@ -55,6 +56,10 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID // 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// 模型路由配置:模型模式 -> 优先账号ID列表
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// 是否启用模型路由配置
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set. // The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"` Edges GroupEdges `json:"edges"`
...@@ -161,7 +166,9 @@ func (*Group) scanValues(columns []string) ([]any, error) { ...@@ -161,7 +166,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case group.FieldIsExclusive, group.FieldClaudeCodeOnly: case group.FieldModelRouting:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
...@@ -315,6 +322,20 @@ func (_m *Group) assignValues(columns []string, values []any) error { ...@@ -315,6 +322,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.FallbackGroupID = new(int64) _m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64 *_m.FallbackGroupID = value.Int64
} }
case group.FieldModelRouting:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.ModelRouting); err != nil {
return fmt.Errorf("unmarshal field model_routing: %w", err)
}
}
case group.FieldModelRoutingEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field model_routing_enabled", values[i])
} else if value.Valid {
_m.ModelRoutingEnabled = value.Bool
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
...@@ -465,6 +486,12 @@ func (_m *Group) String() string { ...@@ -465,6 +486,12 @@ func (_m *Group) String() string {
builder.WriteString("fallback_group_id=") builder.WriteString("fallback_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v)) builder.WriteString(fmt.Sprintf("%v", *v))
} }
builder.WriteString(", ")
builder.WriteString("model_routing=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
builder.WriteString(", ")
builder.WriteString("model_routing_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }
......
...@@ -53,6 +53,10 @@ const ( ...@@ -53,6 +53,10 @@ const (
FieldClaudeCodeOnly = "claude_code_only" FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id" FieldFallbackGroupID = "fallback_group_id"
// FieldModelRouting holds the string denoting the model_routing field in the database.
FieldModelRouting = "model_routing"
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
FieldModelRoutingEnabled = "model_routing_enabled"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys" EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
...@@ -147,6 +151,8 @@ var Columns = []string{ ...@@ -147,6 +151,8 @@ var Columns = []string{
FieldImagePrice4k, FieldImagePrice4k,
FieldClaudeCodeOnly, FieldClaudeCodeOnly,
FieldFallbackGroupID, FieldFallbackGroupID,
FieldModelRouting,
FieldModelRoutingEnabled,
} }
var ( var (
...@@ -204,6 +210,8 @@ var ( ...@@ -204,6 +210,8 @@ var (
DefaultDefaultValidityDays int DefaultDefaultValidityDays int
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
DefaultModelRoutingEnabled bool
) )
// OrderOption defines the ordering options for the Group queries. // OrderOption defines the ordering options for the Group queries.
...@@ -309,6 +317,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { ...@@ -309,6 +317,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
} }
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count. // ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {
......
...@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group { ...@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
} }
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
func ModelRoutingEnabled(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group { func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
...@@ -1065,6 +1070,26 @@ func FallbackGroupIDNotNil() predicate.Group { ...@@ -1065,6 +1070,26 @@ func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
} }
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
func ModelRoutingIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
}
// ModelRoutingNotNil applies the NotNil predicate on the "model_routing" field.
func ModelRoutingNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldModelRouting))
}
// ModelRoutingEnabledEQ applies the EQ predicate on the "model_routing_enabled" field.
func ModelRoutingEnabledEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
}
// ModelRoutingEnabledNEQ applies the NEQ predicate on the "model_routing_enabled" field.
func ModelRoutingEnabledNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group { func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) { return predicate.Group(func(s *sql.Selector) {
......
...@@ -286,6 +286,26 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { ...@@ -286,6 +286,26 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
return _c return _c
} }
// SetModelRouting sets the "model_routing" field.
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
_c.mutation.SetModelRouting(v)
return _c
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_c *GroupCreate) SetModelRoutingEnabled(v bool) *GroupCreate {
_c.mutation.SetModelRoutingEnabled(v)
return _c
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
if v != nil {
_c.SetModelRoutingEnabled(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...) _c.mutation.AddAPIKeyIDs(ids...)
...@@ -455,6 +475,10 @@ func (_c *GroupCreate) defaults() error { ...@@ -455,6 +475,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultClaudeCodeOnly v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v) _c.mutation.SetClaudeCodeOnly(v)
} }
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
v := group.DefaultModelRoutingEnabled
_c.mutation.SetModelRoutingEnabled(v)
}
return nil return nil
} }
...@@ -510,6 +534,9 @@ func (_c *GroupCreate) check() error { ...@@ -510,6 +534,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
} }
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
}
return nil return nil
} }
...@@ -613,6 +640,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { ...@@ -613,6 +640,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value _node.FallbackGroupID = &value
} }
if value, ok := _c.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
_node.ModelRouting = value
}
if value, ok := _c.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
_node.ModelRoutingEnabled = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1093,6 +1128,36 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { ...@@ -1093,6 +1128,36 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
return u return u
} }
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
u.Set(group.FieldModelRouting, v)
return u
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsert) UpdateModelRouting() *GroupUpsert {
u.SetExcluded(group.FieldModelRouting)
return u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsert) ClearModelRouting() *GroupUpsert {
u.SetNull(group.FieldModelRouting)
return u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsert) SetModelRoutingEnabled(v bool) *GroupUpsert {
u.Set(group.FieldModelRoutingEnabled, v)
return u
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
u.SetExcluded(group.FieldModelRoutingEnabled)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
...@@ -1516,6 +1581,41 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { ...@@ -1516,6 +1581,41 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
}) })
} }
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetModelRouting(v)
})
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateModelRouting() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRouting()
})
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsertOne) ClearModelRouting() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearModelRouting()
})
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsertOne) SetModelRoutingEnabled(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetModelRoutingEnabled(v)
})
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRoutingEnabled()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error { func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
...@@ -2105,6 +2205,41 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { ...@@ -2105,6 +2205,41 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
}) })
} }
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetModelRouting(v)
})
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateModelRouting() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRouting()
})
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsertBulk) ClearModelRouting() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearModelRouting()
})
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsertBulk) SetModelRoutingEnabled(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetModelRoutingEnabled(v)
})
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRoutingEnabled()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error { func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {
......
...@@ -395,6 +395,32 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { ...@@ -395,6 +395,32 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
return _u return _u
} }
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
_u.mutation.SetModelRouting(v)
return _u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (_u *GroupUpdate) ClearModelRouting() *GroupUpdate {
_u.mutation.ClearModelRouting()
return _u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_u *GroupUpdate) SetModelRoutingEnabled(v bool) *GroupUpdate {
_u.mutation.SetModelRoutingEnabled(v)
return _u
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
if v != nil {
_u.SetModelRoutingEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -803,6 +829,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -803,6 +829,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.FallbackGroupIDCleared() { if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
} }
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
if _u.mutation.ModelRoutingCleared() {
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
}
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1478,6 +1513,32 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { ...@@ -1478,6 +1513,32 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
return _u return _u
} }
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
_u.mutation.SetModelRouting(v)
return _u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (_u *GroupUpdateOne) ClearModelRouting() *GroupUpdateOne {
_u.mutation.ClearModelRouting()
return _u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_u *GroupUpdateOne) SetModelRoutingEnabled(v bool) *GroupUpdateOne {
_u.mutation.SetModelRoutingEnabled(v)
return _u
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetModelRoutingEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -1916,6 +1977,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) ...@@ -1916,6 +1977,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.FallbackGroupIDCleared() { if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
} }
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
if _u.mutation.ModelRoutingCleared() {
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
}
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
......
...@@ -226,6 +226,8 @@ var ( ...@@ -226,6 +226,8 @@ var (
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
} }
// GroupsTable holds the schema information for the "groups" table. // GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{ GroupsTable = &schema.Table{
......
...@@ -3864,6 +3864,8 @@ type GroupMutation struct { ...@@ -3864,6 +3864,8 @@ type GroupMutation struct {
claude_code_only *bool claude_code_only *bool
fallback_group_id *int64 fallback_group_id *int64
addfallback_group_id *int64 addfallback_group_id *int64
model_routing *map[string][]int64
model_routing_enabled *bool
clearedFields map[string]struct{} clearedFields map[string]struct{}
api_keys map[int64]struct{} api_keys map[int64]struct{}
removedapi_keys map[int64]struct{} removedapi_keys map[int64]struct{}
...@@ -4974,6 +4976,91 @@ func (m *GroupMutation) ResetFallbackGroupID() { ...@@ -4974,6 +4976,91 @@ func (m *GroupMutation) ResetFallbackGroupID() {
delete(m.clearedFields, group.FieldFallbackGroupID) delete(m.clearedFields, group.FieldFallbackGroupID)
} }
   
// SetModelRouting sets the "model_routing" field.
func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
m.model_routing = &value
}
// ModelRouting returns the value of the "model_routing" field in the mutation.
func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) {
v := m.model_routing
if v == nil {
return
}
return *v, true
}
// OldModelRouting returns the old "model_routing" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldModelRouting is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldModelRouting requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldModelRouting: %w", err)
}
return oldValue.ModelRouting, nil
}
// ClearModelRouting clears the value of the "model_routing" field.
func (m *GroupMutation) ClearModelRouting() {
m.model_routing = nil
m.clearedFields[group.FieldModelRouting] = struct{}{}
}
// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation.
func (m *GroupMutation) ModelRoutingCleared() bool {
_, ok := m.clearedFields[group.FieldModelRouting]
return ok
}
// ResetModelRouting resets all changes to the "model_routing" field.
func (m *GroupMutation) ResetModelRouting() {
m.model_routing = nil
delete(m.clearedFields, group.FieldModelRouting)
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (m *GroupMutation) SetModelRoutingEnabled(b bool) {
m.model_routing_enabled = &b
}
// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation.
func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) {
v := m.model_routing_enabled
if v == nil {
return
}
return *v, true
}
// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err)
}
return oldValue.ModelRoutingEnabled, nil
}
// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field.
func (m *GroupMutation) ResetModelRoutingEnabled() {
m.model_routing_enabled = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil { if m.api_keys == nil {
...@@ -5332,7 +5419,7 @@ func (m *GroupMutation) Type() string { ...@@ -5332,7 +5419,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *GroupMutation) Fields() []string { func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 19) fields := make([]string, 0, 21)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt) fields = append(fields, group.FieldCreatedAt)
} }
...@@ -5390,6 +5477,12 @@ func (m *GroupMutation) Fields() []string { ...@@ -5390,6 +5477,12 @@ func (m *GroupMutation) Fields() []string {
if m.fallback_group_id != nil { if m.fallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID) fields = append(fields, group.FieldFallbackGroupID)
} }
if m.model_routing != nil {
fields = append(fields, group.FieldModelRouting)
}
if m.model_routing_enabled != nil {
fields = append(fields, group.FieldModelRoutingEnabled)
}
return fields return fields
} }
   
...@@ -5436,6 +5529,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { ...@@ -5436,6 +5529,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ClaudeCodeOnly() return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
return m.FallbackGroupID() return m.FallbackGroupID()
case group.FieldModelRouting:
return m.ModelRouting()
case group.FieldModelRoutingEnabled:
return m.ModelRoutingEnabled()
} }
return nil, false return nil, false
} }
...@@ -5483,6 +5580,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e ...@@ -5483,6 +5580,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldClaudeCodeOnly(ctx) return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
return m.OldFallbackGroupID(ctx) return m.OldFallbackGroupID(ctx)
case group.FieldModelRouting:
return m.OldModelRouting(ctx)
case group.FieldModelRoutingEnabled:
return m.OldModelRoutingEnabled(ctx)
} }
return nil, fmt.Errorf("unknown Group field %s", name) return nil, fmt.Errorf("unknown Group field %s", name)
} }
...@@ -5625,6 +5726,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { ...@@ -5625,6 +5726,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
} }
m.SetFallbackGroupID(v) m.SetFallbackGroupID(v)
return nil return nil
case group.FieldModelRouting:
v, ok := value.(map[string][]int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetModelRouting(v)
return nil
case group.FieldModelRoutingEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetModelRoutingEnabled(v)
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
...@@ -5793,6 +5908,9 @@ func (m *GroupMutation) ClearedFields() []string { ...@@ -5793,6 +5908,9 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldFallbackGroupID) { if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID) fields = append(fields, group.FieldFallbackGroupID)
} }
if m.FieldCleared(group.FieldModelRouting) {
fields = append(fields, group.FieldModelRouting)
}
return fields return fields
} }
   
...@@ -5834,6 +5952,9 @@ func (m *GroupMutation) ClearField(name string) error { ...@@ -5834,6 +5952,9 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
m.ClearFallbackGroupID() m.ClearFallbackGroupID()
return nil return nil
case group.FieldModelRouting:
m.ClearModelRouting()
return nil
} }
return fmt.Errorf("unknown Group nullable field %s", name) return fmt.Errorf("unknown Group nullable field %s", name)
} }
...@@ -5899,6 +6020,12 @@ func (m *GroupMutation) ResetField(name string) error { ...@@ -5899,6 +6020,12 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
m.ResetFallbackGroupID() m.ResetFallbackGroupID()
return nil return nil
case group.FieldModelRouting:
m.ResetModelRouting()
return nil
case group.FieldModelRoutingEnabled:
m.ResetModelRoutingEnabled()
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
......
...@@ -280,6 +280,10 @@ func init() { ...@@ -280,6 +280,10 @@ func init() {
groupDescClaudeCodeOnly := groupFields[14].Descriptor() groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[17].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
promocodeFields := schema.PromoCode{}.Fields() promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields _ = promocodeFields
// promocodeDescCode is the schema descriptor for code field. // promocodeDescCode is the schema descriptor for code field.
......
...@@ -95,6 +95,17 @@ func (Group) Fields() []ent.Field { ...@@ -95,6 +95,17 @@ func (Group) Fields() []ent.Field {
Optional(). Optional().
Nillable(). Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"), Comment("非 Claude Code 请求降级使用的分组 ID"),
// 模型路由配置 (added by migration 040)
field.JSON("model_routing", map[string][]int64{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
// 模型路由开关 (added by migration 041)
field.Bool("model_routing_enabled").
Default(false).
Comment("是否启用模型路由配置"),
} }
} }
......
...@@ -40,6 +40,9 @@ type CreateGroupRequest struct { ...@@ -40,6 +40,9 @@ type CreateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
} }
// UpdateGroupRequest represents update group request // UpdateGroupRequest represents update group request
...@@ -60,6 +63,9 @@ type UpdateGroupRequest struct { ...@@ -60,6 +63,9 @@ type UpdateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"` ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
} }
// List handles listing all groups with pagination // List handles listing all groups with pagination
...@@ -149,20 +155,22 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -149,20 +155,22 @@ func (h *GroupHandler) Create(c *gin.Context) {
} }
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Platform: req.Platform, Platform: req.Platform,
RateMultiplier: req.RateMultiplier, RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive, IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType, SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD, DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD, WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -188,21 +196,23 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -188,21 +196,23 @@ func (h *GroupHandler) Update(c *gin.Context) {
} }
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Platform: req.Platform, Platform: req.Platform,
RateMultiplier: req.RateMultiplier, RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive, IsExclusive: req.IsExclusive,
Status: req.Status, Status: req.Status,
SubscriptionType: req.SubscriptionType, SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD, DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD, WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
......
...@@ -73,25 +73,27 @@ func GroupFromServiceShallow(g *service.Group) *Group { ...@@ -73,25 +73,27 @@ func GroupFromServiceShallow(g *service.Group) *Group {
return nil return nil
} }
return &Group{ return &Group{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Description: g.Description, Description: g.Description,
Platform: g.Platform, Platform: g.Platform,
RateMultiplier: g.RateMultiplier, RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive, IsExclusive: g.IsExclusive,
Status: g.Status, Status: g.Status,
SubscriptionType: g.SubscriptionType, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD, DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD, WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD, MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt, ModelRouting: g.ModelRouting,
UpdatedAt: g.UpdatedAt, ModelRoutingEnabled: g.ModelRoutingEnabled,
AccountCount: g.AccountCount, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
} }
} }
......
...@@ -58,6 +58,10 @@ type Group struct { ...@@ -58,6 +58,10 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
......
...@@ -136,6 +136,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -136,6 +136,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice4k, group.FieldImagePrice4k,
group.FieldClaudeCodeOnly, group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID, group.FieldFallbackGroupID,
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
) )
}). }).
Only(ctx) Only(ctx)
...@@ -422,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -422,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
......
...@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID) SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
}
created, err := builder.Save(ctx) created, err := builder.Save(ctx)
if err == nil { if err == nil {
...@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly) SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 处理 FallbackGroupID:nil 时清除,否则设置 // 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil { if groupIn.FallbackGroupID != nil {
...@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearFallbackGroupID() builder = builder.ClearFallbackGroupID()
} }
// 处理 ModelRouting:nil 时清除,否则设置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
} else {
builder = builder.ClearModelRouting()
}
updated, err := builder.Save(ctx) updated, err := builder.Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
......
...@@ -106,6 +106,9 @@ type CreateGroupInput struct { ...@@ -106,6 +106,9 @@ type CreateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
} }
type UpdateGroupInput struct { type UpdateGroupInput struct {
...@@ -125,6 +128,9 @@ type UpdateGroupInput struct { ...@@ -125,6 +128,9 @@ type UpdateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
} }
type CreateAccountInput struct { type CreateAccountInput struct {
...@@ -581,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -581,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly, ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID, FallbackGroupID: input.FallbackGroupID,
ModelRouting: input.ModelRouting,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -709,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -709,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
} }
// 模型路由配置
if input.ModelRouting != nil {
group.ModelRouting = input.ModelRouting
}
if input.ModelRoutingEnabled != nil {
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
}
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
} }
......
...@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
...@@ -207,20 +207,22 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -207,20 +207,22 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
} }
if apiKey.Group != nil { if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{ snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID, ID: apiKey.Group.ID,
Name: apiKey.Group.Name, Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform, Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status, Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType, SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier, RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD, DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K, ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupID: apiKey.Group.FallbackGroupID,
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
} }
} }
return snapshot return snapshot
...@@ -248,21 +250,23 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -248,21 +250,23 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
} }
if snapshot.Group != nil { if snapshot.Group != nil {
apiKey.Group = &Group{ apiKey.Group = &Group{
ID: snapshot.Group.ID, ID: snapshot.Group.ID,
Name: snapshot.Group.Name, Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform, Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status, Status: snapshot.Group.Status,
Hydrated: true, Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType, SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier, RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD, DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K, ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupID: snapshot.Group.FallbackGroupID,
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
} }
} }
return apiKey return apiKey
......
...@@ -172,12 +172,16 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -172,12 +172,16 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
Concurrency: 3, Concurrency: 3,
}, },
Group: &APIKeyAuthGroupSnapshot{ Group: &APIKeyAuthGroupSnapshot{
ID: groupID, ID: groupID,
Name: "g", Name: "g",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Status: StatusActive, Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1, RateMultiplier: 1,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-opus-*": {1, 2},
},
}, },
}, },
} }
...@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, int64(1), apiKey.ID) require.Equal(t, int64(1), apiKey.ID)
require.Equal(t, int64(2), apiKey.User.ID) require.Equal(t, int64(2), apiKey.User.ID)
require.Equal(t, groupID, apiKey.Group.ID) require.Equal(t, groupID, apiKey.Group.ID)
require.True(t, apiKey.Group.ModelRoutingEnabled)
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
} }
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
......
...@@ -696,11 +696,11 @@ func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) { ...@@ -696,11 +696,11 @@ func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
Credentials: map[string]any{ Credentials: map[string]any{
"access_token": "old-access-token", "access_token": "old-access-token",
"refresh_token": "old-refresh-token", "refresh_token": "old-refresh-token",
"expires_at": expiresAt, "expires_at": expiresAt,
"custom_field": "should-be-preserved", "custom_field": "should-be-preserved",
"organization": "test-org", "organization": "test-org",
}, },
} }
accountRepo.account = account accountRepo.account = account
......
...@@ -1059,6 +1059,60 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1059,6 +1059,60 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号") require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
}) })
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
groupID := int64(1)
sessionHash := "sticky"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-a": {1},
"claude-b": {2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // legacy path
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) { t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
...@@ -1347,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) ...@@ -1347,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
ID: groupID, ID: groupID,
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Status: StatusActive, Status: StatusActive,
Hydrated: true,
} }
groupRepo := &mockGroupRepoForGateway{ groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group}, groups: map[int64]*Group{groupID: group},
...@@ -1404,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { ...@@ -1404,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ID: fallbackID, ID: fallbackID,
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Status: StatusActive, Status: StatusActive,
Hydrated: true,
} }
ctx = context.WithValue(ctx, ctxkey.Group, group) ctx = context.WithValue(ctx, ctxkey.Group, group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment