Commit 7dddd065 authored by yangjianbo's avatar yangjianbo
Browse files
parents 25a0d49a e78c8646
......@@ -166,14 +166,14 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs.
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
return _c
}
// AddAPIKeys adds the "api_keys" edges to the ApiKey entity.
func (_c *UserCreate) AddAPIKeys(v ...*ApiKey) *UserCreate {
// AddAPIKeys adds the "api_keys" edges to the APIKey entity.
func (_c *UserCreate) AddAPIKeys(v ...*APIKey) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
......
......@@ -30,7 +30,7 @@ type UserQuery struct {
order []user.OrderOption
inters []Interceptor
predicates []predicate.User
withAPIKeys *ApiKeyQuery
withAPIKeys *APIKeyQuery
withRedeemCodes *RedeemCodeQuery
withSubscriptions *UserSubscriptionQuery
withAssignedSubscriptions *UserSubscriptionQuery
......@@ -75,8 +75,8 @@ func (_q *UserQuery) Order(o ...user.OrderOption) *UserQuery {
}
// QueryAPIKeys chains the current query on the "api_keys" edge.
func (_q *UserQuery) QueryAPIKeys() *ApiKeyQuery {
query := (&ApiKeyClient{config: _q.config}).Query()
func (_q *UserQuery) QueryAPIKeys() *APIKeyQuery {
query := (&APIKeyClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
......@@ -458,8 +458,8 @@ func (_q *UserQuery) Clone() *UserQuery {
// WithAPIKeys tells the query-builder to eager-load the nodes that are connected to
// the "api_keys" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAPIKeys(opts ...func(*ApiKeyQuery)) *UserQuery {
query := (&ApiKeyClient{config: _q.config}).Query()
func (_q *UserQuery) WithAPIKeys(opts ...func(*APIKeyQuery)) *UserQuery {
query := (&APIKeyClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
......@@ -653,8 +653,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
}
if query := _q.withAPIKeys; query != nil {
if err := _q.loadAPIKeys(ctx, query, nodes,
func(n *User) { n.Edges.APIKeys = []*ApiKey{} },
func(n *User, e *ApiKey) { n.Edges.APIKeys = append(n.Edges.APIKeys, e) }); err != nil {
func(n *User) { n.Edges.APIKeys = []*APIKey{} },
func(n *User, e *APIKey) { n.Edges.APIKeys = append(n.Edges.APIKeys, e) }); err != nil {
return nil, err
}
}
......@@ -712,7 +712,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nodes, nil
}
func (_q *UserQuery) loadAPIKeys(ctx context.Context, query *ApiKeyQuery, nodes []*User, init func(*User), assign func(*User, *ApiKey)) error {
func (_q *UserQuery) loadAPIKeys(ctx context.Context, query *APIKeyQuery, nodes []*User, init func(*User), assign func(*User, *APIKey)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
......@@ -725,7 +725,7 @@ func (_q *UserQuery) loadAPIKeys(ctx context.Context, query *ApiKeyQuery, nodes
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(apikey.FieldUserID)
}
query.Where(predicate.ApiKey(func(s *sql.Selector) {
query.Where(predicate.APIKey(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.APIKeysColumn), fks...))
}))
neighbors, err := query.All(ctx)
......
......@@ -186,14 +186,14 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs.
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
return _u
}
// AddAPIKeys adds the "api_keys" edges to the ApiKey entity.
func (_u *UserUpdate) AddAPIKeys(v ...*ApiKey) *UserUpdate {
// AddAPIKeys adds the "api_keys" edges to the APIKey entity.
func (_u *UserUpdate) AddAPIKeys(v ...*APIKey) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
......@@ -296,20 +296,20 @@ func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
}
// ClearAPIKeys clears all "api_keys" edges to the ApiKey entity.
// ClearAPIKeys clears all "api_keys" edges to the APIKey entity.
func (_u *UserUpdate) ClearAPIKeys() *UserUpdate {
_u.mutation.ClearAPIKeys()
return _u
}
// RemoveAPIKeyIDs removes the "api_keys" edge to ApiKey entities by IDs.
// RemoveAPIKeyIDs removes the "api_keys" edge to APIKey entities by IDs.
func (_u *UserUpdate) RemoveAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAPIKeyIDs(ids...)
return _u
}
// RemoveAPIKeys removes "api_keys" edges to ApiKey entities.
func (_u *UserUpdate) RemoveAPIKeys(v ...*ApiKey) *UserUpdate {
// RemoveAPIKeys removes "api_keys" edges to APIKey entities.
func (_u *UserUpdate) RemoveAPIKeys(v ...*APIKey) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
......@@ -1065,14 +1065,14 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs.
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
return _u
}
// AddAPIKeys adds the "api_keys" edges to the ApiKey entity.
func (_u *UserUpdateOne) AddAPIKeys(v ...*ApiKey) *UserUpdateOne {
// AddAPIKeys adds the "api_keys" edges to the APIKey entity.
func (_u *UserUpdateOne) AddAPIKeys(v ...*APIKey) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
......@@ -1175,20 +1175,20 @@ func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
}
// ClearAPIKeys clears all "api_keys" edges to the ApiKey entity.
// ClearAPIKeys clears all "api_keys" edges to the APIKey entity.
func (_u *UserUpdateOne) ClearAPIKeys() *UserUpdateOne {
_u.mutation.ClearAPIKeys()
return _u
}
// RemoveAPIKeyIDs removes the "api_keys" edge to ApiKey entities by IDs.
// RemoveAPIKeyIDs removes the "api_keys" edge to APIKey entities by IDs.
func (_u *UserUpdateOne) RemoveAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAPIKeyIDs(ids...)
return _u
}
// RemoveAPIKeys removes "api_keys" edges to ApiKey entities.
func (_u *UserUpdateOne) RemoveAPIKeys(v ...*ApiKey) *UserUpdateOne {
// RemoveAPIKeys removes "api_keys" edges to APIKey entities.
func (_u *UserUpdateOne) RemoveAPIKeys(v ...*APIKey) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
......
// Package config provides configuration loading, defaults, and validation.
package config
import (
......@@ -206,7 +207,7 @@ type GatewayConfig struct {
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
......@@ -312,7 +313,7 @@ type DefaultConfig struct {
AdminPassword string `mapstructure:"admin_password"`
UserConcurrency int `mapstructure:"user_concurrency"`
UserBalance float64 `mapstructure:"user_balance"`
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
APIKeyPrefix string `mapstructure:"api_key_prefix"`
RateMultiplier float64 `mapstructure:"rate_multiplier"`
}
......
// Package admin provides HTTP handlers for administrative operations.
package admin
import (
"errors"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
......@@ -78,6 +81,7 @@ type CreateAccountRequest struct {
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// UpdateAccountRequest represents update account request
......@@ -92,6 +96,7 @@ type UpdateAccountRequest struct {
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
......@@ -105,6 +110,7 @@ type BulkUpdateAccountsRequest struct {
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// AccountWithConcurrency extends Account with real-time concurrency info
......@@ -179,6 +185,9 @@ func (h *AccountHandler) Create(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Platform: req.Platform,
......@@ -189,8 +198,27 @@ func (h *AccountHandler) Create(c *gin.Context) {
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
response.ErrorFrom(c, err)
return
}
......@@ -213,6 +241,9 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name,
Type: req.Type,
......@@ -223,8 +254,27 @@ func (h *AccountHandler) Update(c *gin.Context) {
Priority: req.Priority, // 指针类型,nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
response.ErrorFrom(c, err)
return
}
......@@ -568,6 +618,9 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
......@@ -592,6 +645,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
response.ErrorFrom(c, err)
......@@ -781,6 +835,49 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
}
// GetTempUnschedulable handles getting temporary unschedulable status
// GET /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
state, err := h.rateLimitService.GetTempUnschedStatus(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if state == nil || state.UntilUnix <= time.Now().Unix() {
response.Success(c, gin.H{"active": false})
return
}
response.Success(c, gin.H{
"active": true,
"state": state,
})
}
// ClearTempUnschedulable handles clearing temporary unschedulable status
// DELETE /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) ClearTempUnschedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
if err := h.rateLimitService.ClearTempUnschedulable(c.Request.Context(), accountID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Temp unschedulable cleared successfully"})
}
// GetTodayStats handles getting account today statistics
// GET /api/v1/admin/accounts/:id/today-stats
func (h *AccountHandler) GetTodayStats(c *gin.Context) {
......
......@@ -237,9 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
return
}
outKeys := make([]dto.ApiKey, 0, len(keys))
outKeys := make([]dto.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}
......@@ -243,9 +243,9 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
return
}
out := make([]dto.ApiKey, 0, len(keys))
out := make([]dto.APIKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
out = append(out, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -39,9 +39,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DocURL: settings.DocURL,
Version: h.version,
})
}
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment