Commit fd0370c0 authored by song's avatar song
Browse files

Add invalid-request fallback routing

parent 316f2fee
...@@ -56,6 +56,8 @@ type Group struct { ...@@ -56,6 +56,8 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID // 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// 无效请求兜底使用的分组 ID
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// 模型路由配置:模型模式 -> 优先账号ID列表 // 模型路由配置:模型模式 -> 优先账号ID列表
ModelRouting map[string][]int64 `json:"model_routing,omitempty"` ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// 是否启用模型路由配置 // 是否启用模型路由配置
...@@ -172,7 +174,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { ...@@ -172,7 +174,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
...@@ -322,6 +324,13 @@ func (_m *Group) assignValues(columns []string, values []any) error { ...@@ -322,6 +324,13 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.FallbackGroupID = new(int64) _m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64 *_m.FallbackGroupID = value.Int64
} }
case group.FieldFallbackGroupIDOnInvalidRequest:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i])
} else if value.Valid {
_m.FallbackGroupIDOnInvalidRequest = new(int64)
*_m.FallbackGroupIDOnInvalidRequest = value.Int64
}
case group.FieldModelRouting: case group.FieldModelRouting:
if value, ok := values[i].(*[]byte); !ok { if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field model_routing", values[i]) return fmt.Errorf("unexpected type %T for field model_routing", values[i])
...@@ -487,6 +496,11 @@ func (_m *Group) String() string { ...@@ -487,6 +496,11 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v)) builder.WriteString(fmt.Sprintf("%v", *v))
} }
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.FallbackGroupIDOnInvalidRequest; v != nil {
builder.WriteString("fallback_group_id_on_invalid_request=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("model_routing=") builder.WriteString("model_routing=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting)) builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
builder.WriteString(", ") builder.WriteString(", ")
......
...@@ -53,6 +53,8 @@ const ( ...@@ -53,6 +53,8 @@ const (
FieldClaudeCodeOnly = "claude_code_only" FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id" FieldFallbackGroupID = "fallback_group_id"
// FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database.
FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request"
// FieldModelRouting holds the string denoting the model_routing field in the database. // FieldModelRouting holds the string denoting the model_routing field in the database.
FieldModelRouting = "model_routing" FieldModelRouting = "model_routing"
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database. // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
...@@ -151,6 +153,7 @@ var Columns = []string{ ...@@ -151,6 +153,7 @@ var Columns = []string{
FieldImagePrice4k, FieldImagePrice4k,
FieldClaudeCodeOnly, FieldClaudeCodeOnly,
FieldFallbackGroupID, FieldFallbackGroupID,
FieldFallbackGroupIDOnInvalidRequest,
FieldModelRouting, FieldModelRouting,
FieldModelRoutingEnabled, FieldModelRoutingEnabled,
} }
...@@ -317,6 +320,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { ...@@ -317,6 +320,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
} }
// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field.
func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc()
}
// ByModelRoutingEnabled orders the results by the model_routing_enabled field. // ByModelRoutingEnabled orders the results by the model_routing_enabled field.
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption { func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc() return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
......
...@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group { ...@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
} }
// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ.
func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
}
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ. // ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
func ModelRoutingEnabled(v bool) predicate.Group { func ModelRoutingEnabled(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
...@@ -1070,6 +1075,56 @@ func FallbackGroupIDNotNil() predicate.Group { ...@@ -1070,6 +1075,56 @@ func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
} }
// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
}
// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
}
// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest))
}
// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest))
}
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field. // ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
func ModelRoutingIsNil() predicate.Group { func ModelRoutingIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldModelRouting)) return predicate.Group(sql.FieldIsNull(FieldModelRouting))
......
...@@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { ...@@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
return _c return _c
} }
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate {
_c.mutation.SetFallbackGroupIDOnInvalidRequest(v)
return _c
}
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate {
if v != nil {
_c.SetFallbackGroupIDOnInvalidRequest(*v)
}
return _c
}
// SetModelRouting sets the "model_routing" field. // SetModelRouting sets the "model_routing" field.
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate { func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
_c.mutation.SetModelRouting(v) _c.mutation.SetModelRouting(v)
...@@ -640,6 +654,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { ...@@ -640,6 +654,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value _node.FallbackGroupID = &value
} }
if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok {
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
_node.FallbackGroupIDOnInvalidRequest = &value
}
if value, ok := _c.mutation.ModelRouting(); ok { if value, ok := _c.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
_node.ModelRouting = value _node.ModelRouting = value
...@@ -1128,6 +1146,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { ...@@ -1128,6 +1146,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
return u return u
} }
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v)
return u
}
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert {
u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest)
return u
}
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v)
return u
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert {
u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest)
return u
}
// SetModelRouting sets the "model_routing" field. // SetModelRouting sets the "model_routing" field.
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert { func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
u.Set(group.FieldModelRouting, v) u.Set(group.FieldModelRouting, v)
...@@ -1581,6 +1623,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { ...@@ -1581,6 +1623,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
}) })
} }
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupIDOnInvalidRequest(v)
})
}
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupIDOnInvalidRequest(v)
})
}
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupIDOnInvalidRequest()
})
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupIDOnInvalidRequest()
})
}
// SetModelRouting sets the "model_routing" field. // SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne { func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) { return u.Update(func(s *GroupUpsert) {
...@@ -2205,6 +2275,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { ...@@ -2205,6 +2275,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
}) })
} }
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupIDOnInvalidRequest(v)
})
}
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupIDOnInvalidRequest(v)
})
}
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupIDOnInvalidRequest()
})
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupIDOnInvalidRequest()
})
}
// SetModelRouting sets the "model_routing" field. // SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk { func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) { return u.Update(func(s *GroupUpsert) {
......
...@@ -395,6 +395,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { ...@@ -395,6 +395,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
return _u return _u
} }
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
return _u
}
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate {
if v != nil {
_u.SetFallbackGroupIDOnInvalidRequest(*v)
}
return _u
}
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
return _u
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate {
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
return _u
}
// SetModelRouting sets the "model_routing" field. // SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate { func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
_u.mutation.SetModelRouting(v) _u.mutation.SetModelRouting(v)
...@@ -829,6 +856,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -829,6 +856,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.FallbackGroupIDCleared() { if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
} }
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok { if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
} }
...@@ -1513,6 +1549,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { ...@@ -1513,6 +1549,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
return _u return _u
} }
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
return _u
}
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne {
if v != nil {
_u.SetFallbackGroupIDOnInvalidRequest(*v)
}
return _u
}
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
return _u
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne {
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
return _u
}
// SetModelRouting sets the "model_routing" field. // SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne { func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
_u.mutation.SetModelRouting(v) _u.mutation.SetModelRouting(v)
...@@ -1977,6 +2040,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) ...@@ -1977,6 +2040,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.FallbackGroupIDCleared() { if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
} }
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok { if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
} }
......
...@@ -226,6 +226,7 @@ var ( ...@@ -226,6 +226,7 @@ var (
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, {Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
} }
......
...@@ -3864,6 +3864,8 @@ type GroupMutation struct { ...@@ -3864,6 +3864,8 @@ type GroupMutation struct {
claude_code_only *bool claude_code_only *bool
fallback_group_id *int64 fallback_group_id *int64
addfallback_group_id *int64 addfallback_group_id *int64
fallback_group_id_on_invalid_request *int64
addfallback_group_id_on_invalid_request *int64
model_routing *map[string][]int64 model_routing *map[string][]int64
model_routing_enabled *bool model_routing_enabled *bool
clearedFields map[string]struct{} clearedFields map[string]struct{}
...@@ -4976,6 +4978,76 @@ func (m *GroupMutation) ResetFallbackGroupID() { ...@@ -4976,6 +4978,76 @@ func (m *GroupMutation) ResetFallbackGroupID() {
delete(m.clearedFields, group.FieldFallbackGroupID) delete(m.clearedFields, group.FieldFallbackGroupID)
} }
   
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) {
m.fallback_group_id_on_invalid_request = &i
m.addfallback_group_id_on_invalid_request = nil
}
// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation.
func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
v := m.fallback_group_id_on_invalid_request
if v == nil {
return
}
return *v, true
}
// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err)
}
return oldValue.FallbackGroupIDOnInvalidRequest, nil
}
// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) {
if m.addfallback_group_id_on_invalid_request != nil {
*m.addfallback_group_id_on_invalid_request += i
} else {
m.addfallback_group_id_on_invalid_request = &i
}
}
// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation.
func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
v := m.addfallback_group_id_on_invalid_request
if v == nil {
return
}
return *v, true
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() {
m.fallback_group_id_on_invalid_request = nil
m.addfallback_group_id_on_invalid_request = nil
m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{}
}
// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation.
func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool {
_, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest]
return ok
}
// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() {
m.fallback_group_id_on_invalid_request = nil
m.addfallback_group_id_on_invalid_request = nil
delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest)
}
// SetModelRouting sets the "model_routing" field. // SetModelRouting sets the "model_routing" field.
func (m *GroupMutation) SetModelRouting(value map[string][]int64) { func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
m.model_routing = &value m.model_routing = &value
...@@ -5419,7 +5491,7 @@ func (m *GroupMutation) Type() string { ...@@ -5419,7 +5491,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *GroupMutation) Fields() []string { func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 21) fields := make([]string, 0, 22)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt) fields = append(fields, group.FieldCreatedAt)
} }
...@@ -5477,6 +5549,9 @@ func (m *GroupMutation) Fields() []string { ...@@ -5477,6 +5549,9 @@ func (m *GroupMutation) Fields() []string {
if m.fallback_group_id != nil { if m.fallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID) fields = append(fields, group.FieldFallbackGroupID)
} }
if m.fallback_group_id_on_invalid_request != nil {
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
}
if m.model_routing != nil { if m.model_routing != nil {
fields = append(fields, group.FieldModelRouting) fields = append(fields, group.FieldModelRouting)
} }
...@@ -5529,6 +5604,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { ...@@ -5529,6 +5604,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ClaudeCodeOnly() return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
return m.FallbackGroupID() return m.FallbackGroupID()
case group.FieldFallbackGroupIDOnInvalidRequest:
return m.FallbackGroupIDOnInvalidRequest()
case group.FieldModelRouting: case group.FieldModelRouting:
return m.ModelRouting() return m.ModelRouting()
case group.FieldModelRoutingEnabled: case group.FieldModelRoutingEnabled:
...@@ -5580,6 +5657,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e ...@@ -5580,6 +5657,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldClaudeCodeOnly(ctx) return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
return m.OldFallbackGroupID(ctx) return m.OldFallbackGroupID(ctx)
case group.FieldFallbackGroupIDOnInvalidRequest:
return m.OldFallbackGroupIDOnInvalidRequest(ctx)
case group.FieldModelRouting: case group.FieldModelRouting:
return m.OldModelRouting(ctx) return m.OldModelRouting(ctx)
case group.FieldModelRoutingEnabled: case group.FieldModelRoutingEnabled:
...@@ -5726,6 +5805,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { ...@@ -5726,6 +5805,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
} }
m.SetFallbackGroupID(v) m.SetFallbackGroupID(v)
return nil return nil
case group.FieldFallbackGroupIDOnInvalidRequest:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetFallbackGroupIDOnInvalidRequest(v)
return nil
case group.FieldModelRouting: case group.FieldModelRouting:
v, ok := value.(map[string][]int64) v, ok := value.(map[string][]int64)
if !ok { if !ok {
...@@ -5775,6 +5861,9 @@ func (m *GroupMutation) AddedFields() []string { ...@@ -5775,6 +5861,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addfallback_group_id != nil { if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID) fields = append(fields, group.FieldFallbackGroupID)
} }
if m.addfallback_group_id_on_invalid_request != nil {
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
}
return fields return fields
} }
   
...@@ -5801,6 +5890,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { ...@@ -5801,6 +5890,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedImagePrice4k() return m.AddedImagePrice4k()
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID() return m.AddedFallbackGroupID()
case group.FieldFallbackGroupIDOnInvalidRequest:
return m.AddedFallbackGroupIDOnInvalidRequest()
} }
return nil, false return nil, false
} }
...@@ -5873,6 +5964,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { ...@@ -5873,6 +5964,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
} }
m.AddFallbackGroupID(v) m.AddFallbackGroupID(v)
return nil return nil
case group.FieldFallbackGroupIDOnInvalidRequest:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddFallbackGroupIDOnInvalidRequest(v)
return nil
} }
return fmt.Errorf("unknown Group numeric field %s", name) return fmt.Errorf("unknown Group numeric field %s", name)
} }
...@@ -5908,6 +6006,9 @@ func (m *GroupMutation) ClearedFields() []string { ...@@ -5908,6 +6006,9 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldFallbackGroupID) { if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID) fields = append(fields, group.FieldFallbackGroupID)
} }
if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) {
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
}
if m.FieldCleared(group.FieldModelRouting) { if m.FieldCleared(group.FieldModelRouting) {
fields = append(fields, group.FieldModelRouting) fields = append(fields, group.FieldModelRouting)
} }
...@@ -5952,6 +6053,9 @@ func (m *GroupMutation) ClearField(name string) error { ...@@ -5952,6 +6053,9 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
m.ClearFallbackGroupID() m.ClearFallbackGroupID()
return nil return nil
case group.FieldFallbackGroupIDOnInvalidRequest:
m.ClearFallbackGroupIDOnInvalidRequest()
return nil
case group.FieldModelRouting: case group.FieldModelRouting:
m.ClearModelRouting() m.ClearModelRouting()
return nil return nil
...@@ -6020,6 +6124,9 @@ func (m *GroupMutation) ResetField(name string) error { ...@@ -6020,6 +6124,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
m.ResetFallbackGroupID() m.ResetFallbackGroupID()
return nil return nil
case group.FieldFallbackGroupIDOnInvalidRequest:
m.ResetFallbackGroupIDOnInvalidRequest()
return nil
case group.FieldModelRouting: case group.FieldModelRouting:
m.ResetModelRouting() m.ResetModelRouting()
return nil return nil
......
...@@ -281,7 +281,7 @@ func init() { ...@@ -281,7 +281,7 @@ func init() {
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[17].Descriptor() groupDescModelRoutingEnabled := groupFields[18].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
promocodeFields := schema.PromoCode{}.Fields() promocodeFields := schema.PromoCode{}.Fields()
......
...@@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field { ...@@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field {
Optional(). Optional().
Nillable(). Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"), Comment("非 Claude Code 请求降级使用的分组 ID"),
field.Int64("fallback_group_id_on_invalid_request").
Optional().
Nillable().
Comment("无效请求兜底使用的分组 ID"),
// 模型路由配置 (added by migration 040) // 模型路由配置 (added by migration 040)
field.JSON("model_routing", map[string][]int64{}). field.JSON("model_routing", map[string][]int64{}).
......
...@@ -40,6 +40,7 @@ type CreateGroupRequest struct { ...@@ -40,6 +40,7 @@ type CreateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"` ModelRoutingEnabled bool `json:"model_routing_enabled"`
...@@ -63,6 +64,7 @@ type UpdateGroupRequest struct { ...@@ -63,6 +64,7 @@ type UpdateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"` ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"` ModelRoutingEnabled *bool `json:"model_routing_enabled"`
...@@ -169,6 +171,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -169,6 +171,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
}) })
...@@ -211,6 +214,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -211,6 +214,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
}) })
......
...@@ -89,6 +89,7 @@ func GroupFromServiceShallow(g *service.Group) *Group { ...@@ -89,6 +89,7 @@ func GroupFromServiceShallow(g *service.Group) *Group {
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
......
...@@ -57,6 +57,8 @@ type Group struct { ...@@ -57,6 +57,8 @@ type Group struct {
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...@@ -325,14 +326,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -325,14 +326,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
} }
currentAPIKey := apiKey
currentSubscription := subscription
var fallbackGroupID *int64
if apiKey.Group != nil {
fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest
}
fallbackUsed := false
for {
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
retryWithFallback := false
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -399,7 +410,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -399,7 +410,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false accountWaitCounted = false
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
...@@ -417,6 +428,41 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -417,6 +428,41 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc() accountReleaseFunc()
} }
if err != nil { if err != nil {
var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) {
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
if err != nil {
log.Printf("Resolve fallback group failed: %v", err)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
if fallbackGroup.Platform != service.PlatformAnthropic ||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
c.Request = c.Request.WithContext(ctx)
currentAPIKey = fallbackAPIKey
currentSubscription = nil
fallbackUsed = true
retryWithFallback = true
break
}
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
...@@ -444,10 +490,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -444,10 +490,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: currentAPIKey,
User: apiKey.User, User: currentAPIKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: currentSubscription,
UserAgent: ua, UserAgent: ua,
IPAddress: clientIP, IPAddress: clientIP,
}); err != nil { }); err != nil {
...@@ -456,6 +502,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -456,6 +502,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}(result, account, userAgent, clientIP) }(result, account, userAgent, clientIP)
return return
} }
if !retryWithFallback {
return
}
}
} }
// Models handles listing available models // Models handles listing available models
...@@ -518,6 +569,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) { ...@@ -518,6 +569,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
}) })
} }
func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey {
if apiKey == nil || group == nil {
return apiKey
}
cloned := *apiKey
groupID := group.ID
cloned.GroupID = &groupID
cloned.Group = group
return &cloned
}
// Usage handles getting account balance for CC Switch integration // Usage handles getting account balance for CC Switch integration
// GET /v1/usage // GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
......
...@@ -136,6 +136,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -136,6 +136,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice4k, group.FieldImagePrice4k,
group.FieldClaudeCodeOnly, group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID, group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
group.FieldModelRoutingEnabled, group.FieldModelRoutingEnabled,
group.FieldModelRouting, group.FieldModelRouting,
) )
...@@ -424,6 +425,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -424,6 +425,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
......
...@@ -50,6 +50,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -50,6 +50,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 设置模型路由配置 // 设置模型路由配置
...@@ -116,6 +117,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -116,6 +117,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
} else { } else {
builder = builder.ClearFallbackGroupID() builder = builder.ClearFallbackGroupID()
} }
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
} else {
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
}
// 处理 ModelRouting:nil 时清除,否则设置 // 处理 ModelRouting:nil 时清除,否则设置
if groupIn.ModelRouting != nil { if groupIn.ModelRouting != nil {
......
...@@ -108,6 +108,8 @@ type CreateGroupInput struct { ...@@ -108,6 +108,8 @@ type CreateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由 ModelRoutingEnabled bool // 是否启用模型路由
...@@ -130,6 +132,8 @@ type UpdateGroupInput struct { ...@@ -130,6 +132,8 @@ type UpdateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由 ModelRoutingEnabled *bool // 是否启用模型路由
...@@ -572,6 +576,16 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -572,6 +576,16 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return nil, err return nil, err
} }
} }
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
fallbackOnInvalidRequest = nil
}
// 校验无效请求兜底分组
if fallbackOnInvalidRequest != nil {
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
return nil, err
}
}
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
...@@ -589,6 +603,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -589,6 +603,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly, ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID, FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting, ModelRouting: input.ModelRouting,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
...@@ -651,6 +666,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro ...@@ -651,6 +666,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
} }
} }
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
// currentGroupID: 当前分组 ID(新建时为 0)
// platform/subscriptionType: 当前分组的有效平台/订阅类型
// fallbackGroupID: 兜底分组 ID
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
if platform != PlatformAnthropic && platform != PlatformAntigravity {
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
}
if subscriptionType == SubscriptionTypeSubscription {
return fmt.Errorf("subscription groups cannot set invalid request fallback")
}
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
return fmt.Errorf("cannot set self as invalid request fallback group")
}
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
if fallbackGroup.Platform != PlatformAnthropic {
return fmt.Errorf("fallback group must be anthropic platform")
}
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
return fmt.Errorf("fallback group cannot be subscription type")
}
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
}
return nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
...@@ -717,6 +763,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -717,6 +763,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.FallbackGroupID = nil group.FallbackGroupID = nil
} }
} }
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
if input.FallbackGroupIDOnInvalidRequest != nil {
if *input.FallbackGroupIDOnInvalidRequest > 0 {
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
} else {
fallbackOnInvalidRequest = nil
}
}
if fallbackOnInvalidRequest != nil {
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
return nil, err
}
}
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
// 模型路由配置 // 模型路由配置
if input.ModelRouting != nil { if input.ModelRouting != nil {
......
...@@ -378,3 +378,374 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int ...@@ -378,3 +378,374 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call") panic("unexpected DeleteAccountGroupsByGroupID call")
} }
type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group
created *Group
updated *Group
}
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
s.created = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
s.updated = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
tests := []struct {
name string
fallback *Group
wantMessage string
}{
{
name: "openai_target",
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "antigravity_target",
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "subscription_group",
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
wantMessage: "fallback group cannot be subscription type",
},
{
name: "nested_fallback",
fallback: &Group{
ID: 10,
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
},
wantMessage: "fallback group cannot have invalid request fallback configured",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
fallbackID := tc.fallback.ID
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: tc.fallback,
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantMessage)
require.Nil(t, repo.created)
})
}
}
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group not found")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
zero := int64(0)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
SubscriptionType: SubscriptionTypeSubscription,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
clear := int64(0)
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
FallbackGroupIDOnInvalidRequest: &clear,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
...@@ -62,6 +62,17 @@ type antigravityRetryLoopResult struct { ...@@ -62,6 +62,17 @@ type antigravityRetryLoopResult struct {
resp *http.Response resp *http.Response
} }
// PromptTooLongError 表示上游明确返回 prompt too long
type PromptTooLongError struct {
StatusCode int
RequestID string
Body []byte
}
func (e *PromptTooLongError) Error() string {
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环 // antigravityRetryLoop 执行带 URL fallback 的重试循环
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
...@@ -930,6 +941,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -930,6 +941,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 处理错误响应(重试后仍失败或不触发重试) // 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if resp.StatusCode == http.StatusBadRequest {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
}
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "prompt_too_long",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &PromptTooLongError{
StatusCode: resp.StatusCode,
RequestID: resp.Header.Get("x-request-id"),
Body: respBody,
}
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
...@@ -1019,21 +1063,55 @@ func isSignatureRelatedError(respBody []byte) bool { ...@@ -1019,21 +1063,55 @@ func isSignatureRelatedError(respBody []byte) bool {
return false return false
} }
func isPromptTooLongError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
return strings.Contains(msg, "prompt is too long")
}
func extractAntigravityErrorMessage(body []byte) string { func extractAntigravityErrorMessage(body []byte) string {
var payload map[string]any var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil { if err := json.Unmarshal(body, &payload); err != nil {
return "" return ""
} }
parseNestedMessage := func(msg string) string {
trimmed := strings.TrimSpace(msg)
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
return ""
}
var nested map[string]any
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
return ""
}
if errObj, ok := nested["error"].(map[string]any); ok {
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
return ""
}
// Google-style: {"error": {"message": "..."}} // Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok { if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg return msg
} }
} }
// Fallback: top-level message // Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" { if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg return msg
} }
...@@ -2209,6 +2287,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou ...@@ -2209,6 +2287,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
} }
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error { func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
statusStr := "UNKNOWN" statusStr := "UNKNOWN"
switch status { switch status {
......
package service package service
import ( import (
"bytes"
"context"
"encoding/json" "encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -81,3 +87,77 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) { ...@@ -81,3 +87,77 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
require.Equal(t, "secret plan", blocks[0]["text"]) require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"]) require.Equal(t, "tool_use", blocks[1]["type"])
} }
func TestIsPromptTooLongError(t *testing.T) {
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
}
type httpUpstreamStub struct {
resp *http.Response
err error
}
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return s.resp, s.err
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-5",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"max_tokens": 1,
"stream": false,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
resp := &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{"X-Request-Id": []string{"req-1"}},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
}
result, err := svc.Forward(context.Background(), c, account, body)
require.Nil(t, result)
var promptErr *PromptTooLongError
require.ErrorAs(t, err, &promptErr)
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
require.Equal(t, "req-1", promptErr.RequestID)
require.NotEmpty(t, promptErr.Body)
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, "prompt_too_long", events[0].Kind)
}
...@@ -37,6 +37,7 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -37,6 +37,7 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty. // Only anthropic groups use these fields; others may leave them empty.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment