Commit a4210588 authored by Edric Li's avatar Edric Li
Browse files

feat(groups): add Claude Code client restriction and session isolation

- Add claude_code_only field to restrict groups to Claude Code clients only
- Add fallback_group_id for non-Claude Code requests to use alternate group
- Implement ClaudeCodeValidator for User-Agent detection
- Add group-level session binding isolation (groupID in Redis key)
- Prevent cross-group sticky session pollution
- Update frontend with Claude Code restriction controls
parent 958ffe7a
......@@ -51,6 +51,10 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
......@@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case group.FieldIsExclusive:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays:
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString)
......@@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64
}
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
} else if value.Valid {
_m.ClaudeCodeOnly = value.Bool
}
case group.FieldFallbackGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i])
} else if value.Valid {
_m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64
}
default:
_m.selectValues.Set(columns[i], values[i])
}
......@@ -440,6 +457,14 @@ func (_m *Group) String() string {
builder.WriteString("image_price_4k=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")
if v := _m.FallbackGroupID; v != nil {
builder.WriteString("fallback_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteByte(')')
return builder.String()
}
......
......@@ -49,6 +49,10 @@ const (
FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
......@@ -141,6 +145,8 @@ var Columns = []string{
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
}
var (
......@@ -196,6 +202,8 @@ var (
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
)
// OrderOption defines the ordering options for the Group queries.
......@@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
}
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
}
// ByFallbackGroupID orders the results by the fallback_group_id field.
func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
......
......@@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
}
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
}
// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ.
func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
......@@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
}
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
}
// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v))
}
// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field.
func FallbackGroupIDEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field.
func FallbackGroupIDNEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v))
}
// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field.
func FallbackGroupIDIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...))
}
// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field.
func FallbackGroupIDNotIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...))
}
// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field.
func FallbackGroupIDGT(v int64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v))
}
// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field.
func FallbackGroupIDGTE(v int64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v))
}
// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field.
func FallbackGroupIDLT(v int64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v))
}
// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field.
func FallbackGroupIDLTE(v int64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v))
}
// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field.
func FallbackGroupIDIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID))
}
// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field.
func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
......
......@@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
return _c
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate {
if v != nil {
_c.SetClaudeCodeOnly(*v)
}
return _c
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate {
_c.mutation.SetFallbackGroupID(v)
return _c
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
if v != nil {
_c.SetFallbackGroupID(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
......@@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
}
return nil
}
......@@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
return nil
}
......@@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value
}
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
}
if value, ok := _c.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
return u
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert {
u.SetExcluded(group.FieldClaudeCodeOnly)
return u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert {
u.Set(group.FieldFallbackGroupID, v)
return u
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert {
u.SetExcluded(group.FieldFallbackGroupID)
return u
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert {
u.Add(group.FieldFallbackGroupID, v)
return u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
u.SetNull(group.FieldFallbackGroupID)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
......@@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetClaudeCodeOnly(v)
})
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateClaudeCodeOnly()
})
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupID(v)
})
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupID(v)
})
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupID()
})
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupID()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
......@@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetClaudeCodeOnly(v)
})
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateClaudeCodeOnly()
})
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupID(v)
})
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupID(v)
})
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupID()
})
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupID()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
......
......@@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
return _u
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate {
if v != nil {
_u.SetClaudeCodeOnly(*v)
}
return _u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate {
_u.mutation.ResetFallbackGroupID()
_u.mutation.SetFallbackGroupID(v)
return _u
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate {
if v != nil {
_u.SetFallbackGroupID(*v)
}
return _u
}
// AddFallbackGroupID adds value to the "fallback_group_id" field.
func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate {
_u.mutation.AddFallbackGroupID(v)
return _u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
_u.mutation.ClearFallbackGroupID()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
if value, ok := _u.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
return _u
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetClaudeCodeOnly(*v)
}
return _u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne {
_u.mutation.ResetFallbackGroupID()
_u.mutation.SetFallbackGroupID(v)
return _u
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne {
if v != nil {
_u.SetFallbackGroupID(*v)
}
return _u
}
// AddFallbackGroupID adds value to the "fallback_group_id" field.
func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne {
_u.mutation.AddFallbackGroupID(v)
return _u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
_u.mutation.ClearFallbackGroupID()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
if value, ok := _u.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......
......@@ -221,6 +221,8 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
......
......@@ -3590,6 +3590,9 @@ type GroupMutation struct {
addimage_price_2k *float64
image_price_4k *float64
addimage_price_4k *float64
claude_code_only *bool
fallback_group_id *int64
addfallback_group_id *int64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
......@@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() {
delete(m.clearedFields, group.FieldImagePrice4k)
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
m.claude_code_only = &b
}
// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
v := m.claude_code_only
if v == nil {
return
}
return *v, true
}
// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
}
return oldValue.ClaudeCodeOnly, nil
}
// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
func (m *GroupMutation) ResetClaudeCodeOnly() {
m.claude_code_only = nil
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (m *GroupMutation) SetFallbackGroupID(i int64) {
m.fallback_group_id = &i
m.addfallback_group_id = nil
}
// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
v := m.fallback_group_id
if v == nil {
return
}
return *v, true
}
// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
}
return oldValue.FallbackGroupID, nil
}
// AddFallbackGroupID adds i to the "fallback_group_id" field.
func (m *GroupMutation) AddFallbackGroupID(i int64) {
if m.addfallback_group_id != nil {
*m.addfallback_group_id += i
} else {
m.addfallback_group_id = &i
}
}
// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
v := m.addfallback_group_id
if v == nil {
return
}
return *v, true
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (m *GroupMutation) ClearFallbackGroupID() {
m.fallback_group_id = nil
m.addfallback_group_id = nil
m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
}
// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
func (m *GroupMutation) FallbackGroupIDCleared() bool {
_, ok := m.clearedFields[group.FieldFallbackGroupID]
return ok
}
// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
func (m *GroupMutation) ResetFallbackGroupID() {
m.fallback_group_id = nil
m.addfallback_group_id = nil
delete(m.clearedFields, group.FieldFallbackGroupID)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
......@@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 17)
fields := make([]string, 0, 19)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
......@@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string {
if m.image_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
if m.claude_code_only != nil {
fields = append(fields, group.FieldClaudeCodeOnly)
}
if m.fallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields
}
......@@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ImagePrice2k()
case group.FieldImagePrice4k:
return m.ImagePrice4k()
case group.FieldClaudeCodeOnly:
return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID:
return m.FallbackGroupID()
}
return nil, false
}
......@@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldImagePrice2k(ctx)
case group.FieldImagePrice4k:
return m.OldImagePrice4k(ctx)
case group.FieldClaudeCodeOnly:
return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID:
return m.OldFallbackGroupID(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
......@@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetImagePrice4k(v)
return nil
case group.FieldClaudeCodeOnly:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetClaudeCodeOnly(v)
return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetFallbackGroupID(v)
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
......@@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addimage_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields
}
......@@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedImagePrice2k()
case group.FieldImagePrice4k:
return m.AddedImagePrice4k()
case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID()
}
return nil, false
}
......@@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddImagePrice4k(v)
return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddFallbackGroupID(v)
return nil
}
return fmt.Errorf("unknown Group numeric field %s", name)
}
......@@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldImagePrice4k) {
fields = append(fields, group.FieldImagePrice4k)
}
if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields
}
......@@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldImagePrice4k:
m.ClearImagePrice4k()
return nil
case group.FieldFallbackGroupID:
m.ClearFallbackGroupID()
return nil
}
return fmt.Errorf("unknown Group nullable field %s", name)
}
......@@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldImagePrice4k:
m.ResetImagePrice4k()
return nil
case group.FieldClaudeCodeOnly:
m.ResetClaudeCodeOnly()
return nil
case group.FieldFallbackGroupID:
m.ResetFallbackGroupID()
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
......
......@@ -270,6 +270,10 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
proxyMixin := schema.Proxy{}.Mixin()
proxyMixinHooks1 := proxyMixin[1].Hooks()
proxy.Hooks[0] = proxyMixinHooks1[0]
......
......@@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field {
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
Comment("是否仅允许 Claude Code 客户端"),
field.Int64("fallback_group_id").
Optional().
Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"),
}
}
......@@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge {
edge.From("allowed_users", User.Type).
Ref("allowed_groups").
Through("user_allowed_groups", UserAllowedGroup.Type),
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
// 这样允许多个分组指向同一个降级分组(M2O 关系)
}
}
......
......@@ -37,6 +37,8 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
}
// UpdateGroupRequest represents update group request
......@@ -55,6 +57,8 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
}
// List handles listing all groups with pagination
......@@ -150,6 +154,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
})
if err != nil {
response.ErrorFrom(c, err)
......@@ -188,6 +194,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
})
if err != nil {
response.ErrorFrom(c, err)
......
......@@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
......
......@@ -52,6 +52,10 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
......
......@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext(c, body)
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
......@@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
......@@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
......@@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext(c, body)
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
......
......@@ -2,6 +2,7 @@ package handler
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
......@@ -13,6 +14,26 @@ import (
"github.com/gin-gonic/gin"
)
// claudeCodeValidator is a singleton validator for Claude Code client detection
var claudeCodeValidator = service.NewClaudeCodeValidator()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 解析请求体为 map
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
// 验证是否为 Claude Code 客户端
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
c.Request = c.Request.WithContext(ctx)
}
// 并发槽位等待相关常量
//
// 性能优化说明:
......
......@@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext(c, body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
......@@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
......
......@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
......
......@@ -7,4 +7,6 @@ type Key string
const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
)
......@@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
......
......@@ -2,6 +2,7 @@ package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}
......@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
created, err := builder.Save(ctx)
if err == nil {
......@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
builder := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
......@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
} else {
builder = builder.ClearFallbackGroupID()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment