"backend/internal/vscode:/vscode.git/clone" did not exist on "297f08c683f8d02843418b4531c4e932361bca92"
Commit 7319122e authored by LLLLLLiulei's avatar LLLLLLiulei
Browse files

merge upstream/main

parents 029994a8 4809fa4f
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
...@@ -220,6 +221,33 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err ...@@ -220,6 +221,33 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
} }
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f ErrorPassthroughRuleFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q)
}
// The TraverseErrorPassthroughRule type is an adapter to allow the use of ordinary function as Traverser.
type TraverseErrorPassthroughRule func(context.Context, *ent.ErrorPassthroughRuleQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseErrorPassthroughRule) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseErrorPassthroughRule) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q)
}
// The GroupFunc type is an adapter to allow the use of ordinary function as a Querier. // The GroupFunc type is an adapter to allow the use of ordinary function as a Querier.
type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error) type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error)
...@@ -584,6 +612,8 @@ func NewQuery(q ent.Query) (Query, error) { ...@@ -584,6 +612,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery: case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery: case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.PromoCodeQuery: case *ent.PromoCodeQuery:
......
...@@ -309,6 +309,42 @@ var ( ...@@ -309,6 +309,42 @@ var (
}, },
}, },
} }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "enabled", Type: field.TypeBool, Default: true},
{Name: "priority", Type: field.TypeInt, Default: 0},
{Name: "error_codes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "keywords", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "match_mode", Type: field.TypeString, Size: 10, Default: "any"},
{Name: "platforms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "passthrough_code", Type: field.TypeBool, Default: true},
{Name: "response_code", Type: field.TypeInt, Nullable: true},
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
}
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
ErrorPassthroughRulesTable = &schema.Table{
Name: "error_passthrough_rules",
Columns: ErrorPassthroughRulesColumns,
PrimaryKey: []*schema.Column{ErrorPassthroughRulesColumns[0]},
Indexes: []*schema.Index{
{
Name: "errorpassthroughrule_enabled",
Unique: false,
Columns: []*schema.Column{ErrorPassthroughRulesColumns[4]},
},
{
Name: "errorpassthroughrule_priority",
Unique: false,
Columns: []*schema.Column{ErrorPassthroughRulesColumns[5]},
},
},
}
// GroupsColumns holds the columns for the "groups" table. // GroupsColumns holds the columns for the "groups" table.
GroupsColumns = []*schema.Column{ GroupsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "id", Type: field.TypeInt64, Increment: true},
...@@ -950,6 +986,7 @@ var ( ...@@ -950,6 +986,7 @@ var (
AccountGroupsTable, AccountGroupsTable,
AnnouncementsTable, AnnouncementsTable,
AnnouncementReadsTable, AnnouncementReadsTable,
ErrorPassthroughRulesTable,
GroupsTable, GroupsTable,
PromoCodesTable, PromoCodesTable,
PromoCodeUsagesTable, PromoCodeUsagesTable,
...@@ -989,6 +1026,9 @@ func init() { ...@@ -989,6 +1026,9 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{ AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads", Table: "announcement_reads",
} }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
GroupsTable.Annotation = &entsql.Annotation{ GroupsTable.Annotation = &entsql.Annotation{
Table: "groups", Table: "groups",
} }
......
This diff is collapsed.
...@@ -21,6 +21,9 @@ type Announcement func(*sql.Selector) ...@@ -21,6 +21,9 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders. // AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector) type AnnouncementRead func(*sql.Selector)
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
// Group is the predicate function for group builders. // Group is the predicate function for group builders.
type Group func(*sql.Selector) type Group func(*sql.Selector)
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
...@@ -270,6 +271,61 @@ func init() { ...@@ -270,6 +271,61 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
errorpassthroughruleFields := schema.ErrorPassthroughRule{}.Fields()
_ = errorpassthroughruleFields
// errorpassthroughruleDescCreatedAt is the schema descriptor for created_at field.
errorpassthroughruleDescCreatedAt := errorpassthroughruleMixinFields0[0].Descriptor()
// errorpassthroughrule.DefaultCreatedAt holds the default value on creation for the created_at field.
errorpassthroughrule.DefaultCreatedAt = errorpassthroughruleDescCreatedAt.Default.(func() time.Time)
// errorpassthroughruleDescUpdatedAt is the schema descriptor for updated_at field.
errorpassthroughruleDescUpdatedAt := errorpassthroughruleMixinFields0[1].Descriptor()
// errorpassthroughrule.DefaultUpdatedAt holds the default value on creation for the updated_at field.
errorpassthroughrule.DefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.Default.(func() time.Time)
// errorpassthroughrule.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
errorpassthroughrule.UpdateDefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.UpdateDefault.(func() time.Time)
// errorpassthroughruleDescName is the schema descriptor for name field.
errorpassthroughruleDescName := errorpassthroughruleFields[0].Descriptor()
// errorpassthroughrule.NameValidator is a validator for the "name" field. It is called by the builders before save.
errorpassthroughrule.NameValidator = func() func(string) error {
validators := errorpassthroughruleDescName.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(name string) error {
for _, fn := range fns {
if err := fn(name); err != nil {
return err
}
}
return nil
}
}()
// errorpassthroughruleDescEnabled is the schema descriptor for enabled field.
errorpassthroughruleDescEnabled := errorpassthroughruleFields[1].Descriptor()
// errorpassthroughrule.DefaultEnabled holds the default value on creation for the enabled field.
errorpassthroughrule.DefaultEnabled = errorpassthroughruleDescEnabled.Default.(bool)
// errorpassthroughruleDescPriority is the schema descriptor for priority field.
errorpassthroughruleDescPriority := errorpassthroughruleFields[2].Descriptor()
// errorpassthroughrule.DefaultPriority holds the default value on creation for the priority field.
errorpassthroughrule.DefaultPriority = errorpassthroughruleDescPriority.Default.(int)
// errorpassthroughruleDescMatchMode is the schema descriptor for match_mode field.
errorpassthroughruleDescMatchMode := errorpassthroughruleFields[5].Descriptor()
// errorpassthroughrule.DefaultMatchMode holds the default value on creation for the match_mode field.
errorpassthroughrule.DefaultMatchMode = errorpassthroughruleDescMatchMode.Default.(string)
// errorpassthroughrule.MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save.
errorpassthroughrule.MatchModeValidator = errorpassthroughruleDescMatchMode.Validators[0].(func(string) error)
// errorpassthroughruleDescPassthroughCode is the schema descriptor for passthrough_code field.
errorpassthroughruleDescPassthroughCode := errorpassthroughruleFields[7].Descriptor()
// errorpassthroughrule.DefaultPassthroughCode holds the default value on creation for the passthrough_code field.
errorpassthroughrule.DefaultPassthroughCode = errorpassthroughruleDescPassthroughCode.Default.(bool)
// errorpassthroughruleDescPassthroughBody is the schema descriptor for passthrough_body field.
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
groupMixin := schema.Group{}.Mixin() groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks() groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0] group.Hooks[0] = groupMixinHooks1[0]
......
// Package schema 定义 Ent ORM 的数据库 schema。
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// ErrorPassthroughRule 定义全局错误透传规则的 schema。
//
// 错误透传规则用于控制上游错误如何返回给客户端:
// - 匹配条件:错误码 + 关键词组合
// - 响应行为:透传原始信息 或 自定义错误信息
// - 响应状态码:可指定返回给客户端的状态码
// - 平台范围:规则适用的平台(Anthropic、OpenAI、Gemini、Antigravity)
type ErrorPassthroughRule struct {
ent.Schema
}
// Annotations 返回 schema 的注解配置。
func (ErrorPassthroughRule) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "error_passthrough_rules"},
}
}
// Mixin 返回该 schema 使用的混入组件。
func (ErrorPassthroughRule) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
// Fields 定义错误透传规则实体的所有字段。
func (ErrorPassthroughRule) Fields() []ent.Field {
return []ent.Field{
// name: 规则名称,用于在界面中标识规则
field.String("name").
MaxLen(100).
NotEmpty(),
// enabled: 是否启用该规则
field.Bool("enabled").
Default(true),
// priority: 规则优先级,数值越小优先级越高
// 匹配时按优先级顺序检查,命中第一个匹配的规则
field.Int("priority").
Default(0),
// error_codes: 匹配的错误码列表(OR关系)
// 例如:[422, 400] 表示匹配 422 或 400 错误码
field.JSON("error_codes", []int{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// keywords: 匹配的关键词列表(OR关系)
// 例如:["context limit", "model not supported"]
// 关键词匹配不区分大小写
field.JSON("keywords", []string{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// match_mode: 匹配模式
// - "any": 错误码匹配 OR 关键词匹配(任一条件满足即可)
// - "all": 错误码匹配 AND 关键词匹配(所有条件都必须满足)
field.String("match_mode").
MaxLen(10).
Default("any"),
// platforms: 适用平台列表
// 例如:["anthropic", "openai", "gemini", "antigravity"]
// 空列表表示适用于所有平台
field.JSON("platforms", []string{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// passthrough_code: 是否透传上游原始状态码
// true: 使用上游返回的状态码
// false: 使用 response_code 指定的状态码
field.Bool("passthrough_code").
Default(true),
// response_code: 自定义响应状态码
// 当 passthrough_code=false 时使用此状态码
field.Int("response_code").
Optional().
Nillable(),
// passthrough_body: 是否透传上游原始错误信息
// true: 使用上游返回的错误信息
// false: 使用 custom_message 指定的错误信息
field.Bool("passthrough_body").
Default(true),
// custom_message: 自定义错误信息
// 当 passthrough_body=false 时使用此错误信息
field.Text("custom_message").
Optional().
Nillable(),
// description: 规则描述,用于说明规则的用途
field.Text("description").
Optional().
Nillable(),
}
}
// Indexes 定义数据库索引,优化查询性能。
func (ErrorPassthroughRule) Indexes() []ent.Index {
return []ent.Index{
index.Fields("enabled"), // 筛选启用的规则
index.Fields("priority"), // 按优先级排序
}
}
...@@ -24,6 +24,8 @@ type Tx struct { ...@@ -24,6 +24,8 @@ type Tx struct {
Announcement *AnnouncementClient Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders. // AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient AnnouncementRead *AnnouncementReadClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders. // Group is the client for interacting with the Group builders.
Group *GroupClient Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders. // PromoCode is the client for interacting with the PromoCode builders.
...@@ -186,6 +188,7 @@ func (tx *Tx) init() { ...@@ -186,6 +188,7 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config) tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config) tx.Group = NewGroupClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
......
module github.com/Wei-Shaw/sub2api module github.com/Wei-Shaw/sub2api
go 1.25.6 go 1.25.7
require ( require (
entgo.io/ent v0.14.5 entgo.io/ent v0.14.5
......
This diff is collapsed.
...@@ -45,6 +45,9 @@ type UpdateUserRequest struct { ...@@ -45,6 +45,9 @@ type UpdateUserRequest struct {
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"` AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
} }
// UpdateBalanceRequest represents balance update request // UpdateBalanceRequest represents balance update request
...@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) { ...@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Status: req.Status, Status: req.Status,
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
......
...@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { ...@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
} }
response.Success(c, out) response.Success(c, out)
} }
// GetUserGroupRates 获取当前用户的专属分组倍率配置
// GET /api/v1/groups/rates
func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rates)
}
...@@ -60,6 +60,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { ...@@ -60,6 +60,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return &AdminUser{ return &AdminUser{
User: *base, User: *base,
Notes: u.Notes, Notes: u.Notes,
GroupRates: u.GroupRates,
} }
} }
......
...@@ -29,6 +29,9 @@ type AdminUser struct { ...@@ -29,6 +29,9 @@ type AdminUser struct {
User User
Notes string `json:"notes"` Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
} }
type APIKey struct { type APIKey struct {
......
...@@ -33,6 +33,7 @@ type GatewayHandler struct { ...@@ -33,6 +33,7 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageService *service.UsageService usageService *service.UsageService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
maxAccountSwitchesGemini int maxAccountSwitchesGemini int
...@@ -48,6 +49,7 @@ func NewGatewayHandler( ...@@ -48,6 +49,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
...@@ -70,6 +72,7 @@ func NewGatewayHandler( ...@@ -70,6 +72,7 @@ func NewGatewayHandler(
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageService: usageService, usageService: usageService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini, maxAccountSwitchesGemini: maxAccountSwitchesGemini,
...@@ -201,7 +204,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -201,7 +204,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 var lastFailoverErr *service.UpstreamFailoverError
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
...@@ -210,7 +213,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -210,7 +213,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return return
} }
account := selection.Account account := selection.Account
...@@ -301,9 +308,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -301,9 +308,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return return
} }
switchCount++ switchCount++
...@@ -352,7 +359,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -352,7 +359,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false retryWithFallback := false
for { for {
...@@ -363,7 +370,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -363,7 +370,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return return
} }
account := selection.Account account := selection.Account
...@@ -487,9 +498,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -487,9 +498,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return return
} }
switchCount++ switchCount++
...@@ -755,7 +766,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT ...@@ -755,7 +766,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
}
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode) status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
} }
......
...@@ -24,6 +24,7 @@ type AdminHandlers struct { ...@@ -24,6 +24,7 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers
......
...@@ -27,6 +27,7 @@ func ProvideAdminHandlers( ...@@ -27,6 +27,7 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler, subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler, usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler, userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
...@@ -47,6 +48,7 @@ func ProvideAdminHandlers( ...@@ -47,6 +48,7 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Usage: usageHandler, Usage: usageHandler,
UserAttribute: userAttributeHandler, UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
} }
} }
...@@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet( ...@@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler, admin.NewSubscriptionHandler,
admin.NewUsageHandler, admin.NewUsageHandler,
admin.NewUserAttributeHandler, admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment