Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a04ae28a
Commit
a04ae28a
authored
Apr 13, 2026
by
陈曦
Browse files
merge v0.1.111
parents
68f67198
ad64190b
Changes
302
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
302 of 302+
files are displayed.
Plain diff
Email patch
backend/internal/server/router.go
View file @
a04ae28a
...
@@ -111,4 +111,5 @@ func registerRoutes(
...
@@ -111,4 +111,5 @@ func registerRoutes(
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
,
settingService
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
,
settingService
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
)
routes
.
RegisterPaymentRoutes
(
v1
,
h
.
Payment
,
h
.
PaymentWebhook
,
h
.
Admin
.
Payment
,
jwtAuth
,
adminAuth
,
settingService
)
}
}
backend/internal/server/routes/auth.go
View file @
a04ae28a
...
@@ -70,6 +70,14 @@ func RegisterAuthRoutes(
...
@@ -70,6 +70,14 @@ func RegisterAuthRoutes(
}),
}),
h
.
Auth
.
CompleteLinuxDoOAuthRegistration
,
h
.
Auth
.
CompleteLinuxDoOAuthRegistration
,
)
)
auth
.
GET
(
"/oauth/oidc/start"
,
h
.
Auth
.
OIDCOAuthStart
)
auth
.
GET
(
"/oauth/oidc/callback"
,
h
.
Auth
.
OIDCOAuthCallback
)
auth
.
POST
(
"/oauth/oidc/complete-registration"
,
rateLimiter
.
LimitWithOptions
(
"oauth-oidc-complete"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
CompleteOIDCOAuthRegistration
,
)
}
}
// 公开设置(无需认证)
// 公开设置(无需认证)
...
...
backend/internal/server/routes/payment.go
0 → 100644
View file @
a04ae28a
package
routes
import
(
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterPaymentRoutes registers all payment-related routes:
// user-facing endpoints, webhook endpoints, and admin endpoints.
func
RegisterPaymentRoutes
(
v1
*
gin
.
RouterGroup
,
paymentHandler
*
handler
.
PaymentHandler
,
webhookHandler
*
handler
.
PaymentWebhookHandler
,
adminPaymentHandler
*
admin
.
PaymentHandler
,
jwtAuth
middleware
.
JWTAuthMiddleware
,
adminAuth
middleware
.
AdminAuthMiddleware
,
settingService
*
service
.
SettingService
,
)
{
// --- User-facing payment endpoints (authenticated) ---
authenticated
:=
v1
.
Group
(
"/payment"
)
authenticated
.
Use
(
gin
.
HandlerFunc
(
jwtAuth
))
authenticated
.
Use
(
middleware
.
BackendModeUserGuard
(
settingService
))
{
authenticated
.
GET
(
"/config"
,
paymentHandler
.
GetPaymentConfig
)
authenticated
.
GET
(
"/checkout-info"
,
paymentHandler
.
GetCheckoutInfo
)
authenticated
.
GET
(
"/plans"
,
paymentHandler
.
GetPlans
)
authenticated
.
GET
(
"/channels"
,
paymentHandler
.
GetChannels
)
authenticated
.
GET
(
"/limits"
,
paymentHandler
.
GetLimits
)
orders
:=
authenticated
.
Group
(
"/orders"
)
{
orders
.
POST
(
""
,
paymentHandler
.
CreateOrder
)
orders
.
POST
(
"/verify"
,
paymentHandler
.
VerifyOrder
)
orders
.
GET
(
"/my"
,
paymentHandler
.
GetMyOrders
)
orders
.
GET
(
"/:id"
,
paymentHandler
.
GetOrder
)
orders
.
POST
(
"/:id/cancel"
,
paymentHandler
.
CancelOrder
)
orders
.
POST
(
"/:id/refund-request"
,
paymentHandler
.
RequestRefund
)
}
}
// --- Public payment endpoints (no auth) ---
// Payment result page needs to verify order status without login
// (user session may have expired during provider redirect).
public
:=
v1
.
Group
(
"/payment/public"
)
{
public
.
POST
(
"/orders/verify"
,
paymentHandler
.
VerifyOrderPublic
)
}
// --- Webhook endpoints (no auth) ---
webhook
:=
v1
.
Group
(
"/payment/webhook"
)
{
// EasyPay sends GET callbacks with query params
webhook
.
GET
(
"/easypay"
,
webhookHandler
.
EasyPayNotify
)
webhook
.
POST
(
"/easypay"
,
webhookHandler
.
EasyPayNotify
)
webhook
.
POST
(
"/alipay"
,
webhookHandler
.
AlipayNotify
)
webhook
.
POST
(
"/wxpay"
,
webhookHandler
.
WxpayNotify
)
webhook
.
POST
(
"/stripe"
,
webhookHandler
.
StripeWebhook
)
}
// --- Admin payment endpoints (admin auth) ---
adminGroup
:=
v1
.
Group
(
"/admin/payment"
)
adminGroup
.
Use
(
gin
.
HandlerFunc
(
adminAuth
))
{
// Dashboard
adminGroup
.
GET
(
"/dashboard"
,
adminPaymentHandler
.
GetDashboard
)
// Config
adminGroup
.
GET
(
"/config"
,
adminPaymentHandler
.
GetConfig
)
adminGroup
.
PUT
(
"/config"
,
adminPaymentHandler
.
UpdateConfig
)
// Orders
adminOrders
:=
adminGroup
.
Group
(
"/orders"
)
{
adminOrders
.
GET
(
""
,
adminPaymentHandler
.
ListOrders
)
adminOrders
.
GET
(
"/:id"
,
adminPaymentHandler
.
GetOrderDetail
)
adminOrders
.
POST
(
"/:id/cancel"
,
adminPaymentHandler
.
CancelOrder
)
adminOrders
.
POST
(
"/:id/retry"
,
adminPaymentHandler
.
RetryFulfillment
)
adminOrders
.
POST
(
"/:id/refund"
,
adminPaymentHandler
.
ProcessRefund
)
}
// Subscription Plans
plans
:=
adminGroup
.
Group
(
"/plans"
)
{
plans
.
GET
(
""
,
adminPaymentHandler
.
ListPlans
)
plans
.
POST
(
""
,
adminPaymentHandler
.
CreatePlan
)
plans
.
PUT
(
"/:id"
,
adminPaymentHandler
.
UpdatePlan
)
plans
.
DELETE
(
"/:id"
,
adminPaymentHandler
.
DeletePlan
)
}
// Provider Instances
providers
:=
adminGroup
.
Group
(
"/providers"
)
{
providers
.
GET
(
""
,
adminPaymentHandler
.
ListProviders
)
providers
.
POST
(
""
,
adminPaymentHandler
.
CreateProvider
)
providers
.
PUT
(
"/:id"
,
adminPaymentHandler
.
UpdateProvider
)
providers
.
DELETE
(
"/:id"
,
adminPaymentHandler
.
DeleteProvider
)
}
}
}
backend/internal/service/admin_service.go
View file @
a04ae28a
...
@@ -21,13 +21,13 @@ import (
...
@@ -21,13 +21,13 @@ import (
// AdminService interface defines admin management operations
// AdminService interface defines admin management operations
type
AdminService
interface
{
type
AdminService
interface
{
// User management
// User management
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
filters
UserListFilters
)
([]
User
,
int64
,
error
)
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
filters
UserListFilters
,
sortBy
,
sortOrder
string
)
([]
User
,
int64
,
error
)
GetUser
(
ctx
context
.
Context
,
id
int64
)
(
*
User
,
error
)
GetUser
(
ctx
context
.
Context
,
id
int64
)
(
*
User
,
error
)
CreateUser
(
ctx
context
.
Context
,
input
*
CreateUserInput
)
(
*
User
,
error
)
CreateUser
(
ctx
context
.
Context
,
input
*
CreateUserInput
)
(
*
User
,
error
)
UpdateUser
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateUserInput
)
(
*
User
,
error
)
UpdateUser
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateUserInput
)
(
*
User
,
error
)
DeleteUser
(
ctx
context
.
Context
,
id
int64
)
error
DeleteUser
(
ctx
context
.
Context
,
id
int64
)
error
UpdateUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
,
operation
string
,
notes
string
)
(
*
User
,
error
)
UpdateUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
,
operation
string
,
notes
string
)
(
*
User
,
error
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
)
([]
APIKey
,
int64
,
error
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
sortBy
,
sortOrder
string
)
([]
APIKey
,
int64
,
error
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// codeType is optional - pass empty string to return all types.
...
@@ -35,7 +35,7 @@ type AdminService interface {
...
@@ -35,7 +35,7 @@ type AdminService interface {
GetUserBalanceHistory
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
codeType
string
)
([]
RedeemCode
,
int64
,
float64
,
error
)
GetUserBalanceHistory
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
codeType
string
)
([]
RedeemCode
,
int64
,
float64
,
error
)
// Group management
// Group management
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
,
sortBy
,
sortOrder
string
)
([]
Group
,
int64
,
error
)
GetAllGroups
(
ctx
context
.
Context
)
([]
Group
,
error
)
GetAllGroups
(
ctx
context
.
Context
)
([]
Group
,
error
)
GetAllGroupsByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
GetAllGroupsByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
GetGroup
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
GetGroup
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
...
@@ -55,7 +55,7 @@ type AdminService interface {
...
@@ -55,7 +55,7 @@ type AdminService interface {
ReplaceUserGroup
(
ctx
context
.
Context
,
userID
,
oldGroupID
,
newGroupID
int64
)
(
*
ReplaceUserGroupResult
,
error
)
ReplaceUserGroup
(
ctx
context
.
Context
,
userID
,
oldGroupID
,
newGroupID
int64
)
(
*
ReplaceUserGroupResult
,
error
)
// Account management
// Account management
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
)
([]
Account
,
int64
,
error
)
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
,
sortBy
,
sortOrder
string
)
([]
Account
,
int64
,
error
)
GetAccount
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
GetAccount
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
GetAccountsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
GetAccountsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
Account
,
error
)
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
Account
,
error
)
...
@@ -77,8 +77,8 @@ type AdminService interface {
...
@@ -77,8 +77,8 @@ type AdminService interface {
CheckMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
CheckMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
// Proxy management
// Proxy management
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
Proxy
,
int64
,
error
)
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
,
sortBy
,
sortOrder
string
)
([]
Proxy
,
int64
,
error
)
ListProxiesWithAccountCount
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
ProxyWithAccountCount
,
int64
,
error
)
ListProxiesWithAccountCount
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
,
sortBy
,
sortOrder
string
)
([]
ProxyWithAccountCount
,
int64
,
error
)
GetAllProxies
(
ctx
context
.
Context
)
([]
Proxy
,
error
)
GetAllProxies
(
ctx
context
.
Context
)
([]
Proxy
,
error
)
GetAllProxiesWithAccountCount
(
ctx
context
.
Context
)
([]
ProxyWithAccountCount
,
error
)
GetAllProxiesWithAccountCount
(
ctx
context
.
Context
)
([]
ProxyWithAccountCount
,
error
)
GetProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
Proxy
,
error
)
GetProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
Proxy
,
error
)
...
@@ -93,7 +93,7 @@ type AdminService interface {
...
@@ -93,7 +93,7 @@ type AdminService interface {
CheckProxyQuality
(
ctx
context
.
Context
,
id
int64
)
(
*
ProxyQualityCheckResult
,
error
)
CheckProxyQuality
(
ctx
context
.
Context
,
id
int64
)
(
*
ProxyQualityCheckResult
,
error
)
// Redeem code management
// Redeem code management
ListRedeemCodes
(
ctx
context
.
Context
,
page
,
pageSize
int
,
codeType
,
status
,
search
string
)
([]
RedeemCode
,
int64
,
error
)
ListRedeemCodes
(
ctx
context
.
Context
,
page
,
pageSize
int
,
codeType
,
status
,
search
string
,
sortBy
,
sortOrder
string
)
([]
RedeemCode
,
int64
,
error
)
GetRedeemCode
(
ctx
context
.
Context
,
id
int64
)
(
*
RedeemCode
,
error
)
GetRedeemCode
(
ctx
context
.
Context
,
id
int64
)
(
*
RedeemCode
,
error
)
GenerateRedeemCodes
(
ctx
context
.
Context
,
input
*
GenerateRedeemCodesInput
)
([]
RedeemCode
,
error
)
GenerateRedeemCodes
(
ctx
context
.
Context
,
input
*
GenerateRedeemCodesInput
)
([]
RedeemCode
,
error
)
DeleteRedeemCode
(
ctx
context
.
Context
,
id
int64
)
error
DeleteRedeemCode
(
ctx
context
.
Context
,
id
int64
)
error
...
@@ -152,10 +152,11 @@ type CreateGroupInput struct {
...
@@ -152,10 +152,11 @@ type CreateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用)
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
SupportedModelScopes
[]
string
// OpenAI Messages 调度配置(仅 openai 平台使用)
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch
bool
AllowMessagesDispatch
bool
DefaultMappedModel
string
DefaultMappedModel
string
RequireOAuthOnly
bool
RequireOAuthOnly
bool
RequirePrivacySet
bool
RequirePrivacySet
bool
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(创建分组后在同一事务内绑定)
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
CopyAccountsFromGroupIDs
[]
int64
}
}
...
@@ -186,10 +187,11 @@ type UpdateGroupInput struct {
...
@@ -186,10 +187,11 @@ type UpdateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用)
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
SupportedModelScopes
*
[]
string
// OpenAI Messages 调度配置(仅 openai 平台使用)
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch
*
bool
AllowMessagesDispatch
*
bool
DefaultMappedModel
*
string
DefaultMappedModel
*
string
RequireOAuthOnly
*
bool
RequireOAuthOnly
*
bool
RequirePrivacySet
*
bool
RequirePrivacySet
*
bool
MessagesDispatchModelConfig
*
OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
CopyAccountsFromGroupIDs
[]
int64
}
}
...
@@ -483,8 +485,8 @@ func NewAdminService(
...
@@ -483,8 +485,8 @@ func NewAdminService(
}
}
// User management implementations
// User management implementations
func
(
s
*
adminServiceImpl
)
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
filters
UserListFilters
)
([]
User
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
filters
UserListFilters
,
sortBy
,
sortOrder
string
)
([]
User
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
sortBy
,
SortOrder
:
sortOrder
}
users
,
result
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
filters
)
users
,
result
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
filters
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
...
@@ -751,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
...
@@ -751,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return
user
,
nil
return
user
,
nil
}
}
func
(
s
*
adminServiceImpl
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
)
([]
APIKey
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
sortBy
,
sortOrder
string
)
([]
APIKey
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
sortBy
,
SortOrder
:
sortOrder
}
keys
,
result
,
err
:=
s
.
apiKeyRepo
.
ListByUserID
(
ctx
,
userID
,
params
,
APIKeyListFilters
{})
keys
,
result
,
err
:=
s
.
apiKeyRepo
.
ListByUserID
(
ctx
,
userID
,
params
,
APIKeyListFilters
{})
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
...
@@ -787,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
...
@@ -787,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
}
}
// Group management implementations
// Group management implementations
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
,
sortBy
,
sortOrder
string
)
([]
Group
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
sortBy
,
SortOrder
:
sortOrder
}
groups
,
result
,
err
:=
s
.
groupRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
status
,
search
,
isExclusive
)
groups
,
result
,
err
:=
s
.
groupRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
status
,
search
,
isExclusive
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
...
@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequireOAuthOnly
:
input
.
RequireOAuthOnly
,
RequireOAuthOnly
:
input
.
RequireOAuthOnly
,
RequirePrivacySet
:
input
.
RequirePrivacySet
,
RequirePrivacySet
:
input
.
RequirePrivacySet
,
DefaultMappedModel
:
input
.
DefaultMappedModel
,
DefaultMappedModel
:
input
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
normalizeOpenAIMessagesDispatchModelConfig
(
input
.
MessagesDispatchModelConfig
),
}
}
sanitizeGroupMessagesDispatchFields
(
group
)
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
...
@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
DefaultMappedModel
!=
nil
{
if
input
.
DefaultMappedModel
!=
nil
{
group
.
DefaultMappedModel
=
*
input
.
DefaultMappedModel
group
.
DefaultMappedModel
=
*
input
.
DefaultMappedModel
}
}
if
input
.
MessagesDispatchModelConfig
!=
nil
{
group
.
MessagesDispatchModelConfig
=
normalizeOpenAIMessagesDispatchModelConfig
(
*
input
.
MessagesDispatchModelConfig
)
}
sanitizeGroupMessagesDispatchFields
(
group
)
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -1456,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
...
@@ -1456,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
}
}
// Account management implementations
// Account management implementations
func
(
s
*
adminServiceImpl
)
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
)
([]
Account
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
,
sortBy
,
sortOrder
string
)
([]
Account
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
sortBy
,
SortOrder
:
sortOrder
}
accounts
,
result
,
err
:=
s
.
accountRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
accountType
,
status
,
search
,
groupID
,
privacyMode
)
accounts
,
result
,
err
:=
s
.
accountRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
accountType
,
status
,
search
,
groupID
,
privacyMode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
...
@@ -1885,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
...
@@ -1885,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
}
}
// Proxy management implementations
// Proxy management implementations
func
(
s
*
adminServiceImpl
)
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
Proxy
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
,
sortBy
,
sortOrder
string
)
([]
Proxy
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
sortBy
,
SortOrder
:
sortOrder
}
proxies
,
result
,
err
:=
s
.
proxyRepo
.
ListWithFilters
(
ctx
,
params
,
protocol
,
status
,
search
)
proxies
,
result
,
err
:=
s
.
proxyRepo
.
ListWithFilters
(
ctx
,
params
,
protocol
,
status
,
search
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
...
@@ -1894,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
...
@@ -1894,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return
proxies
,
result
.
Total
,
nil
return
proxies
,
result
.
Total
,
nil
}
}
func
(
s
*
adminServiceImpl
)
ListProxiesWithAccountCount
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
ProxyWithAccountCount
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListProxiesWithAccountCount
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
,
sortBy
,
sortOrder
string
)
([]
ProxyWithAccountCount
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
sortBy
,
SortOrder
:
sortOrder
}
proxies
,
result
,
err
:=
s
.
proxyRepo
.
ListWithFiltersAndAccountCount
(
ctx
,
params
,
protocol
,
status
,
search
)
proxies
,
result
,
err
:=
s
.
proxyRepo
.
ListWithFiltersAndAccountCount
(
ctx
,
params
,
protocol
,
status
,
search
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
...
@@ -2032,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
...
@@ -2032,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
}
}
// Redeem code management implementations
// Redeem code management implementations
func
(
s
*
adminServiceImpl
)
ListRedeemCodes
(
ctx
context
.
Context
,
page
,
pageSize
int
,
codeType
,
status
,
search
string
)
([]
RedeemCode
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListRedeemCodes
(
ctx
context
.
Context
,
page
,
pageSize
int
,
codeType
,
status
,
search
string
,
sortBy
,
sortOrder
string
)
([]
RedeemCode
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
sortBy
,
SortOrder
:
sortOrder
}
codes
,
result
,
err
:=
s
.
redeemCodeRepo
.
ListWithFilters
(
ctx
,
params
,
codeType
,
status
,
search
)
codes
,
result
,
err
:=
s
.
redeemCodeRepo
.
ListWithFilters
(
ctx
,
params
,
codeType
,
status
,
search
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
...
...
backend/internal/service/admin_service_group_test.go
View file @
a04ae28a
...
@@ -10,6 +10,11 @@ import (
...
@@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
func
ptrString
[
T
~
string
](
v
T
)
*
string
{
s
:=
string
(
v
)
return
&
s
}
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type
groupRepoStubForAdmin
struct
{
type
groupRepoStubForAdmin
struct
{
created
*
Group
// 记录 Create 调用的参数
created
*
Group
// 记录 Create 调用的参数
...
@@ -120,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
...
@@ -120,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
return
nil
return
nil
}
}
func
TestAdminService_ListGroups_PassesSortParams
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{
listWithFiltersGroups
:
[]
Group
{{
ID
:
1
,
Name
:
"g1"
}},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
_
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
3
,
25
,
PlatformOpenAI
,
StatusActive
,
"needle"
,
nil
,
"account_count"
,
"ASC"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
3
,
PageSize
:
25
,
SortBy
:
"account_count"
,
SortOrder
:
"ASC"
,
},
repo
.
listWithFiltersParams
)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
repo
:=
&
groupRepoStubForAdmin
{}
...
@@ -245,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
...
@@ -245,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
}
}
func
TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"dispatch-group"
,
Description
:
"dispatch config"
,
Platform
:
PlatformOpenAI
,
RateMultiplier
:
1.0
,
MessagesDispatchModelConfig
:
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
" gpt-5.4-high "
,
SonnetMappedModel
:
" gpt-5.3-codex "
,
HaikuMappedModel
:
" gpt-5.4-mini-medium "
,
ExactModelMappings
:
map
[
string
]
string
{
" claude-sonnet-4-5-20250929 "
:
" gpt-5.2-high "
,
},
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
Equal
(
t
,
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
"gpt-5.4"
,
SonnetMappedModel
:
"gpt-5.3-codex"
,
HaikuMappedModel
:
"gpt-5.4-mini"
,
ExactModelMappings
:
map
[
string
]
string
{
"claude-sonnet-4-5-20250929"
:
"gpt-5.2"
,
},
},
repo
.
created
.
MessagesDispatchModelConfig
)
}
func
TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig
(
t
*
testing
.
T
)
{
existingGroup
:=
&
Group
{
ID
:
1
,
Name
:
"existing-group"
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForAdmin
{
getByID
:
existingGroup
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
1
,
&
UpdateGroupInput
{
MessagesDispatchModelConfig
:
&
OpenAIMessagesDispatchModelConfig
{
SonnetMappedModel
:
" gpt-5.4-medium "
,
ExactModelMappings
:
map
[
string
]
string
{
" claude-haiku-4-5-20251001 "
:
" gpt-5.4-mini-high "
,
},
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
OpenAIMessagesDispatchModelConfig
{
SonnetMappedModel
:
"gpt-5.4"
,
ExactModelMappings
:
map
[
string
]
string
{
"claude-haiku-4-5-20251001"
:
"gpt-5.4-mini"
,
},
},
repo
.
updated
.
MessagesDispatchModelConfig
)
}
func
TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"anthropic-group"
,
Description
:
"non-openai"
,
Platform
:
PlatformAnthropic
,
RateMultiplier
:
1.0
,
AllowMessagesDispatch
:
true
,
DefaultMappedModel
:
"gpt-5.4"
,
MessagesDispatchModelConfig
:
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
"gpt-5.4"
,
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
False
(
t
,
repo
.
created
.
AllowMessagesDispatch
)
require
.
Empty
(
t
,
repo
.
created
.
DefaultMappedModel
)
require
.
Equal
(
t
,
OpenAIMessagesDispatchModelConfig
{},
repo
.
created
.
MessagesDispatchModelConfig
)
}
func
TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI
(
t
*
testing
.
T
)
{
existingGroup
:=
&
Group
{
ID
:
1
,
Name
:
"existing-openai-group"
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
AllowMessagesDispatch
:
true
,
DefaultMappedModel
:
"gpt-5.4"
,
MessagesDispatchModelConfig
:
OpenAIMessagesDispatchModelConfig
{
SonnetMappedModel
:
"gpt-5.3-codex"
,
},
}
repo
:=
&
groupRepoStubForAdmin
{
getByID
:
existingGroup
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
1
,
&
UpdateGroupInput
{
Platform
:
PlatformAnthropic
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
PlatformAnthropic
,
repo
.
updated
.
Platform
)
require
.
False
(
t
,
repo
.
updated
.
AllowMessagesDispatch
)
require
.
Empty
(
t
,
repo
.
updated
.
DefaultMappedModel
)
require
.
Equal
(
t
,
OpenAIMessagesDispatchModelConfig
{},
repo
.
updated
.
MessagesDispatchModelConfig
)
}
func
TestAdminService_ListGroups_WithSearch
(
t
*
testing
.
T
)
{
func
TestAdminService_ListGroups_WithSearch
(
t
*
testing
.
T
)
{
// 测试:
// 测试:
// 1. search 参数正常传递到 repository 层
// 1. search 参数正常传递到 repository 层
...
@@ -258,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
...
@@ -258,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
1
,
20
,
""
,
""
,
"alpha"
,
nil
)
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
1
,
20
,
""
,
""
,
"alpha"
,
nil
,
""
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
1
,
Name
:
"alpha"
}},
groups
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
1
,
Name
:
"alpha"
}},
groups
)
...
@@ -276,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
...
@@ -276,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
2
,
10
,
""
,
""
,
""
,
nil
)
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
2
,
10
,
""
,
""
,
""
,
nil
,
""
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
groups
)
require
.
Empty
(
t
,
groups
)
require
.
Equal
(
t
,
int64
(
0
),
total
)
require
.
Equal
(
t
,
int64
(
0
),
total
)
...
@@ -295,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
...
@@ -295,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
3
,
50
,
PlatformAntigravity
,
StatusActive
,
"beta"
,
&
isExclusive
)
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
3
,
50
,
PlatformAntigravity
,
StatusActive
,
"beta"
,
&
isExclusive
,
""
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
42
),
total
)
require
.
Equal
(
t
,
int64
(
42
),
total
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
2
,
Name
:
"beta"
}},
groups
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
2
,
Name
:
"beta"
}},
groups
)
...
...
backend/internal/service/admin_service_list_users_test.go
View file @
a04ae28a
...
@@ -13,11 +13,13 @@ import (
...
@@ -13,11 +13,13 @@ import (
type
userRepoStubForListUsers
struct
{
type
userRepoStubForListUsers
struct
{
userRepoStub
userRepoStub
users
[]
User
users
[]
User
err
error
err
error
listWithFiltersParams
pagination
.
PaginationParams
}
}
func
(
s
*
userRepoStubForListUsers
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
_
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
userRepoStubForListUsers
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
_
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersParams
=
params
if
s
.
err
!=
nil
{
if
s
.
err
!=
nil
{
return
nil
,
nil
,
s
.
err
return
nil
,
nil
,
s
.
err
}
}
...
@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
...
@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userGroupRateRepo
:
rateRepo
,
userGroupRateRepo
:
rateRepo
,
}
}
users
,
total
,
err
:=
svc
.
ListUsers
(
context
.
Background
(),
1
,
20
,
UserListFilters
{})
users
,
total
,
err
:=
svc
.
ListUsers
(
context
.
Background
(),
1
,
20
,
UserListFilters
{}
,
""
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
total
)
require
.
Equal
(
t
,
int64
(
2
),
total
)
require
.
Len
(
t
,
users
,
2
)
require
.
Len
(
t
,
users
,
2
)
...
@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
...
@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
require
.
Equal
(
t
,
1.1
,
users
[
0
]
.
GroupRates
[
11
])
require
.
Equal
(
t
,
1.1
,
users
[
0
]
.
GroupRates
[
11
])
require
.
Equal
(
t
,
2.2
,
users
[
1
]
.
GroupRates
[
22
])
require
.
Equal
(
t
,
2.2
,
users
[
1
]
.
GroupRates
[
22
])
}
}
func
TestAdminService_ListUsers_PassesSortParams
(
t
*
testing
.
T
)
{
userRepo
:=
&
userRepoStubForListUsers
{
users
:
[]
User
{{
ID
:
1
,
Email
:
"a@example.com"
}},
}
svc
:=
&
adminServiceImpl
{
userRepo
:
userRepo
}
_
,
_
,
err
:=
svc
.
ListUsers
(
context
.
Background
(),
2
,
50
,
UserListFilters
{},
"email"
,
"ASC"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
2
,
PageSize
:
50
,
SortBy
:
"email"
,
SortOrder
:
"ASC"
,
},
userRepo
.
listWithFiltersParams
)
}
backend/internal/service/admin_service_search_test.go
View file @
a04ae28a
...
@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
...
@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformGemini
,
AccountTypeOAuth
,
StatusActive
,
"acc"
,
0
,
""
)
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformGemini
,
AccountTypeOAuth
,
StatusActive
,
"acc"
,
0
,
""
,
"name"
,
"ASC"
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
),
total
)
require
.
Equal
(
t
,
int64
(
10
),
total
)
require
.
Equal
(
t
,
[]
Account
{{
ID
:
1
,
Name
:
"acc"
}},
accounts
)
require
.
Equal
(
t
,
[]
Account
{{
ID
:
1
,
Name
:
"acc"
}},
accounts
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
,
SortBy
:
"name"
,
SortOrder
:
"ASC"
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
PlatformGemini
,
repo
.
listWithFiltersPlatform
)
require
.
Equal
(
t
,
PlatformGemini
,
repo
.
listWithFiltersPlatform
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
...
@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
...
@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformOpenAI
,
AccountTypeOAuth
,
StatusActive
,
"acc2"
,
0
,
PrivacyModeCFBlocked
)
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformOpenAI
,
AccountTypeOAuth
,
StatusActive
,
"acc2"
,
0
,
PrivacyModeCFBlocked
,
""
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Equal
(
t
,
[]
Account
{{
ID
:
2
,
Name
:
"acc2"
}},
accounts
)
require
.
Equal
(
t
,
[]
Account
{{
ID
:
2
,
Name
:
"acc2"
}},
accounts
)
...
@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
...
@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
proxies
,
total
,
err
:=
svc
.
ListProxies
(
context
.
Background
(),
3
,
50
,
"http"
,
StatusActive
,
"p1"
)
proxies
,
total
,
err
:=
svc
.
ListProxies
(
context
.
Background
(),
3
,
50
,
"http"
,
StatusActive
,
"p1"
,
"name"
,
"ASC"
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
7
),
total
)
require
.
Equal
(
t
,
int64
(
7
),
total
)
require
.
Equal
(
t
,
[]
Proxy
{{
ID
:
2
,
Name
:
"p1"
}},
proxies
)
require
.
Equal
(
t
,
[]
Proxy
{{
ID
:
2
,
Name
:
"p1"
}},
proxies
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
3
,
PageSize
:
50
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
3
,
PageSize
:
50
,
SortBy
:
"name"
,
SortOrder
:
"ASC"
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
"http"
,
repo
.
listWithFiltersProtocol
)
require
.
Equal
(
t
,
"http"
,
repo
.
listWithFiltersProtocol
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"p1"
,
repo
.
listWithFiltersSearch
)
require
.
Equal
(
t
,
"p1"
,
repo
.
listWithFiltersSearch
)
...
@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
...
@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
proxies
,
total
,
err
:=
svc
.
ListProxiesWithAccountCount
(
context
.
Background
(),
2
,
10
,
"socks5"
,
StatusDisabled
,
"p2"
)
proxies
,
total
,
err
:=
svc
.
ListProxiesWithAccountCount
(
context
.
Background
(),
2
,
10
,
"socks5"
,
StatusDisabled
,
"p2"
,
"account_count"
,
"DESC"
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
9
),
total
)
require
.
Equal
(
t
,
int64
(
9
),
total
)
require
.
Equal
(
t
,
[]
ProxyWithAccountCount
{{
Proxy
:
Proxy
{
ID
:
3
,
Name
:
"p2"
},
AccountCount
:
5
}},
proxies
)
require
.
Equal
(
t
,
[]
ProxyWithAccountCount
{{
Proxy
:
Proxy
{
ID
:
3
,
Name
:
"p2"
},
AccountCount
:
5
}},
proxies
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersAndAccountCountCalls
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersAndAccountCountCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
2
,
PageSize
:
10
},
repo
.
listWithFiltersAndAccountCountParams
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
2
,
PageSize
:
10
,
SortBy
:
"account_count"
,
SortOrder
:
"DESC"
},
repo
.
listWithFiltersAndAccountCountParams
)
require
.
Equal
(
t
,
"socks5"
,
repo
.
listWithFiltersAndAccountCountProtocol
)
require
.
Equal
(
t
,
"socks5"
,
repo
.
listWithFiltersAndAccountCountProtocol
)
require
.
Equal
(
t
,
StatusDisabled
,
repo
.
listWithFiltersAndAccountCountStatus
)
require
.
Equal
(
t
,
StatusDisabled
,
repo
.
listWithFiltersAndAccountCountStatus
)
require
.
Equal
(
t
,
"p2"
,
repo
.
listWithFiltersAndAccountCountSearch
)
require
.
Equal
(
t
,
"p2"
,
repo
.
listWithFiltersAndAccountCountSearch
)
...
@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
...
@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
}
}
svc
:=
&
adminServiceImpl
{
redeemCodeRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
redeemCodeRepo
:
repo
}
codes
,
total
,
err
:=
svc
.
ListRedeemCodes
(
context
.
Background
(),
1
,
20
,
RedeemTypeBalance
,
StatusUnused
,
"ABC"
)
codes
,
total
,
err
:=
svc
.
ListRedeemCodes
(
context
.
Background
(),
1
,
20
,
RedeemTypeBalance
,
StatusUnused
,
"ABC"
,
"value"
,
"ASC"
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
3
),
total
)
require
.
Equal
(
t
,
int64
(
3
),
total
)
require
.
Equal
(
t
,
[]
RedeemCode
{{
ID
:
4
,
Code
:
"ABC"
}},
codes
)
require
.
Equal
(
t
,
[]
RedeemCode
{{
ID
:
4
,
Code
:
"ABC"
}},
codes
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
,
SortBy
:
"value"
,
SortOrder
:
"ASC"
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
RedeemTypeBalance
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
RedeemTypeBalance
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
StatusUnused
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
StatusUnused
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"ABC"
,
repo
.
listWithFiltersSearch
)
require
.
Equal
(
t
,
"ABC"
,
repo
.
listWithFiltersSearch
)
...
...
backend/internal/service/api_key_auth_cache.go
View file @
a04ae28a
...
@@ -4,6 +4,7 @@ import "time"
...
@@ -4,6 +4,7 @@ import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type
APIKeyAuthSnapshot
struct
{
type
APIKeyAuthSnapshot
struct
{
Version
int
`json:"version"`
APIKeyID
int64
`json:"api_key_id"`
APIKeyID
int64
`json:"api_key_id"`
UserID
int64
`json:"user_id"`
UserID
int64
`json:"user_id"`
GroupID
*
int64
`json:"group_id,omitempty"`
GroupID
*
int64
`json:"group_id,omitempty"`
...
@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct {
...
@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct {
SupportedModelScopes
[]
string
`json:"supported_model_scopes,omitempty"`
SupportedModelScopes
[]
string
`json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch
bool
`json:"allow_messages_dispatch"`
AllowMessagesDispatch
bool
`json:"allow_messages_dispatch"`
DefaultMappedModel
string
`json:"default_mapped_model,omitempty"`
DefaultMappedModel
string
`json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config,omitempty"`
}
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
a04ae28a
...
@@ -13,6 +13,8 @@ import (
...
@@ -13,6 +13,8 @@ import (
"github.com/dgraph-io/ristretto"
"github.com/dgraph-io/ristretto"
)
)
const
apiKeyAuthSnapshotVersion
=
3
type
apiKeyAuthCacheConfig
struct
{
type
apiKeyAuthCacheConfig
struct
{
l1Size
int
l1Size
int
l1TTL
time
.
Duration
l1TTL
time
.
Duration
...
@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
...
@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
if
entry
.
Snapshot
==
nil
{
if
entry
.
Snapshot
==
nil
{
return
nil
,
false
,
nil
return
nil
,
false
,
nil
}
}
if
entry
.
Snapshot
.
Version
!=
apiKeyAuthSnapshotVersion
{
return
nil
,
false
,
nil
}
return
s
.
snapshotToAPIKey
(
key
,
entry
.
Snapshot
),
true
,
nil
return
s
.
snapshotToAPIKey
(
key
,
entry
.
Snapshot
),
true
,
nil
}
}
...
@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
...
@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
return
nil
return
nil
}
}
snapshot
:=
&
APIKeyAuthSnapshot
{
snapshot
:=
&
APIKeyAuthSnapshot
{
Version
:
apiKeyAuthSnapshotVersion
,
APIKeyID
:
apiKey
.
ID
,
APIKeyID
:
apiKey
.
ID
,
UserID
:
apiKey
.
UserID
,
UserID
:
apiKey
.
UserID
,
GroupID
:
apiKey
.
GroupID
,
GroupID
:
apiKey
.
GroupID
,
...
@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
...
@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
SupportedModelScopes
:
apiKey
.
Group
.
SupportedModelScopes
,
SupportedModelScopes
:
apiKey
.
Group
.
SupportedModelScopes
,
AllowMessagesDispatch
:
apiKey
.
Group
.
AllowMessagesDispatch
,
AllowMessagesDispatch
:
apiKey
.
Group
.
AllowMessagesDispatch
,
DefaultMappedModel
:
apiKey
.
Group
.
DefaultMappedModel
,
DefaultMappedModel
:
apiKey
.
Group
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
apiKey
.
Group
.
MessagesDispatchModelConfig
,
}
}
}
}
return
snapshot
return
snapshot
...
@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
...
@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes
:
snapshot
.
Group
.
SupportedModelScopes
,
SupportedModelScopes
:
snapshot
.
Group
.
SupportedModelScopes
,
AllowMessagesDispatch
:
snapshot
.
Group
.
AllowMessagesDispatch
,
AllowMessagesDispatch
:
snapshot
.
Group
.
AllowMessagesDispatch
,
DefaultMappedModel
:
snapshot
.
Group
.
DefaultMappedModel
,
DefaultMappedModel
:
snapshot
.
Group
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
snapshot
.
Group
.
MessagesDispatchModelConfig
,
}
}
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
s
.
compileAPIKeyIPRules
(
apiKey
)
...
...
backend/internal/service/api_key_service_cache_test.go
View file @
a04ae28a
...
@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
...
@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
groupID
:=
int64
(
9
)
groupID
:=
int64
(
9
)
cacheEntry
:=
&
APIKeyAuthCacheEntry
{
cacheEntry
:=
&
APIKeyAuthCacheEntry
{
Snapshot
:
&
APIKeyAuthSnapshot
{
Snapshot
:
&
APIKeyAuthSnapshot
{
Version
:
apiKeyAuthSnapshotVersion
,
APIKeyID
:
1
,
APIKeyID
:
1
,
UserID
:
2
,
UserID
:
2
,
GroupID
:
&
groupID
,
GroupID
:
&
groupID
,
...
@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
...
@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require
.
Equal
(
t
,
map
[
string
][]
int64
{
"claude-opus-*"
:
{
1
,
2
}},
apiKey
.
Group
.
ModelRouting
)
require
.
Equal
(
t
,
map
[
string
][]
int64
{
"claude-opus-*"
:
{
1
,
2
}},
apiKey
.
Group
.
ModelRouting
)
}
}
func
TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig
(
t
*
testing
.
T
)
{
svc
:=
NewAPIKeyService
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
groupID
:=
int64
(
9
)
apiKey
:=
&
APIKey
{
ID
:
1
,
UserID
:
2
,
GroupID
:
&
groupID
,
Key
:
"k-roundtrip"
,
Status
:
StatusActive
,
User
:
&
User
{
ID
:
2
,
Status
:
StatusActive
,
Role
:
RoleUser
,
Balance
:
10
,
Concurrency
:
3
,
},
Group
:
&
Group
{
ID
:
groupID
,
Name
:
"openai"
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
SubscriptionType
:
SubscriptionTypeStandard
,
RateMultiplier
:
1
,
AllowMessagesDispatch
:
true
,
DefaultMappedModel
:
"gpt-5.4"
,
MessagesDispatchModelConfig
:
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
"gpt-5.4-nano"
,
SonnetMappedModel
:
"gpt-5.3-codex"
,
HaikuMappedModel
:
"gpt-5.4-mini"
,
ExactModelMappings
:
map
[
string
]
string
{
"claude-sonnet-4.5"
:
"gpt-5.4-nano"
,
},
},
},
}
snapshot
:=
svc
.
snapshotFromAPIKey
(
apiKey
)
roundTrip
:=
svc
.
snapshotToAPIKey
(
apiKey
.
Key
,
snapshot
)
require
.
NotNil
(
t
,
roundTrip
)
require
.
NotNil
(
t
,
roundTrip
.
Group
)
require
.
Equal
(
t
,
apiKey
.
Group
.
MessagesDispatchModelConfig
,
roundTrip
.
Group
.
MessagesDispatchModelConfig
)
}
func
TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
var
repoCalls
int32
repo
:=
&
authRepoStub
{
getByKeyForAuth
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
atomic
.
AddInt32
(
&
repoCalls
,
1
)
groupID
:=
int64
(
9
)
return
&
APIKey
{
ID
:
1
,
UserID
:
2
,
GroupID
:
&
groupID
,
Status
:
StatusActive
,
User
:
&
User
{
ID
:
2
,
Status
:
StatusActive
,
Role
:
RoleUser
,
Balance
:
10
,
Concurrency
:
3
,
},
Group
:
&
Group
{
ID
:
groupID
,
Name
:
"openai"
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Hydrated
:
true
,
SubscriptionType
:
SubscriptionTypeStandard
,
RateMultiplier
:
1
,
AllowMessagesDispatch
:
true
,
DefaultMappedModel
:
"gpt-5.4"
,
MessagesDispatchModelConfig
:
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
"gpt-5.4-nano"
,
},
},
},
nil
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
nil
,
cache
,
cfg
)
groupID
:=
int64
(
9
)
cache
.
getAuthCache
=
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
return
&
APIKeyAuthCacheEntry
{
Snapshot
:
&
APIKeyAuthSnapshot
{
APIKeyID
:
1
,
UserID
:
2
,
GroupID
:
&
groupID
,
Status
:
StatusActive
,
User
:
APIKeyAuthUserSnapshot
{
ID
:
2
,
Status
:
StatusActive
,
Role
:
RoleUser
,
Balance
:
10
,
Concurrency
:
3
,
},
Group
:
&
APIKeyAuthGroupSnapshot
{
ID
:
groupID
,
Name
:
"openai"
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
SubscriptionType
:
SubscriptionTypeStandard
,
RateMultiplier
:
1
,
AllowMessagesDispatch
:
true
,
DefaultMappedModel
:
"gpt-5.4"
,
},
},
},
nil
}
apiKey
,
err
:=
svc
.
GetByKey
(
context
.
Background
(),
"k-legacy"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
repoCalls
))
require
.
NotNil
(
t
,
apiKey
.
Group
)
require
.
Equal
(
t
,
"gpt-5.4-nano"
,
apiKey
.
Group
.
MessagesDispatchModelConfig
.
OpusMappedModel
)
}
func
TestAPIKeyService_GetByKey_NegativeCache
(
t
*
testing
.
T
)
{
func
TestAPIKeyService_GetByKey_NegativeCache
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
repo
:=
&
authRepoStub
{
...
...
backend/internal/service/auth_service.go
View file @
a04ae28a
...
@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) {
...
@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) {
func
isReservedEmail
(
email
string
)
bool
{
func
isReservedEmail
(
email
string
)
bool
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
return
strings
.
HasSuffix
(
normalized
,
LinuxDoConnectSyntheticEmailDomain
)
return
strings
.
HasSuffix
(
normalized
,
LinuxDoConnectSyntheticEmailDomain
)
||
strings
.
HasSuffix
(
normalized
,
OIDCConnectSyntheticEmailDomain
)
}
}
// GenerateToken 生成JWT access token
// GenerateToken 生成JWT access token
...
...
backend/internal/service/domain_constants.go
View file @
a04ae28a
...
@@ -71,6 +71,9 @@ const (
...
@@ -71,6 +71,9 @@ const (
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
const
LinuxDoConnectSyntheticEmailDomain
=
"@linuxdo-connect.invalid"
const
LinuxDoConnectSyntheticEmailDomain
=
"@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const
OIDCConnectSyntheticEmailDomain
=
"@oidc-connect.invalid"
// Setting keys
// Setting keys
const
(
const
(
// 注册设置
// 注册设置
...
@@ -105,6 +108,30 @@ const (
...
@@ -105,6 +108,30 @@ const (
SettingKeyLinuxDoConnectClientSecret
=
"linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectClientSecret
=
"linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled
=
"oidc_connect_enabled"
SettingKeyOIDCConnectProviderName
=
"oidc_connect_provider_name"
SettingKeyOIDCConnectClientID
=
"oidc_connect_client_id"
SettingKeyOIDCConnectClientSecret
=
"oidc_connect_client_secret"
SettingKeyOIDCConnectIssuerURL
=
"oidc_connect_issuer_url"
SettingKeyOIDCConnectDiscoveryURL
=
"oidc_connect_discovery_url"
SettingKeyOIDCConnectAuthorizeURL
=
"oidc_connect_authorize_url"
SettingKeyOIDCConnectTokenURL
=
"oidc_connect_token_url"
SettingKeyOIDCConnectUserInfoURL
=
"oidc_connect_userinfo_url"
SettingKeyOIDCConnectJWKSURL
=
"oidc_connect_jwks_url"
SettingKeyOIDCConnectScopes
=
"oidc_connect_scopes"
SettingKeyOIDCConnectRedirectURL
=
"oidc_connect_redirect_url"
SettingKeyOIDCConnectFrontendRedirectURL
=
"oidc_connect_frontend_redirect_url"
SettingKeyOIDCConnectTokenAuthMethod
=
"oidc_connect_token_auth_method"
SettingKeyOIDCConnectUsePKCE
=
"oidc_connect_use_pkce"
SettingKeyOIDCConnectValidateIDToken
=
"oidc_connect_validate_id_token"
SettingKeyOIDCConnectAllowedSigningAlgs
=
"oidc_connect_allowed_signing_algs"
SettingKeyOIDCConnectClockSkewSeconds
=
"oidc_connect_clock_skew_seconds"
SettingKeyOIDCConnectRequireEmailVerified
=
"oidc_connect_require_email_verified"
SettingKeyOIDCConnectUserInfoEmailPath
=
"oidc_connect_userinfo_email_path"
SettingKeyOIDCConnectUserInfoIDPath
=
"oidc_connect_userinfo_id_path"
SettingKeyOIDCConnectUserInfoUsernamePath
=
"oidc_connect_userinfo_username_path"
// OEM设置
// OEM设置
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
...
@@ -116,6 +143,8 @@ const (
...
@@ -116,6 +143,8 @@ const (
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled
=
"purchase_subscription_enabled"
// 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionEnabled
=
"purchase_subscription_enabled"
// 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL
=
"purchase_subscription_url"
// "购买订阅"页面 URL(作为 iframe src)
SettingKeyPurchaseSubscriptionURL
=
"purchase_subscription_url"
// "购买订阅"页面 URL(作为 iframe src)
SettingKeyTableDefaultPageSize
=
"table_default_page_size"
// 表格默认每页条数
SettingKeyTablePageSizeOptions
=
"table_page_size_options"
// 表格可选每页条数(JSON 数组)
SettingKeyCustomMenuItems
=
"custom_menu_items"
// 自定义菜单项(JSON 数组)
SettingKeyCustomMenuItems
=
"custom_menu_items"
// 自定义菜单项(JSON 数组)
SettingKeyCustomEndpoints
=
"custom_endpoints"
// 自定义端点列表(JSON 数组)
SettingKeyCustomEndpoints
=
"custom_endpoints"
// 自定义端点列表(JSON 数组)
...
...
backend/internal/service/gateway_service.go
View file @
a04ae28a
...
@@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
...
@@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
// 注意:强制平台模式不走混合调度
if
(
platform
==
PlatformAnthropic
||
platform
==
PlatformGemini
)
&&
!
hasForcePlatform
{
if
(
platform
==
PlatformAnthropic
||
platform
==
PlatformGemini
)
&&
!
hasForcePlatform
{
return
s
.
selectAccountWithMixedScheduling
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platform
)
account
,
err
:=
s
.
selectAccountWithMixedScheduling
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platform
)
if
err
!=
nil
{
return
nil
,
err
}
return
s
.
hydrateSelectedAccount
(
ctx
,
account
)
}
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
// antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return
s
.
selectAccountForModelWithPlatform
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platform
)
account
,
err
:=
s
.
selectAccountForModelWithPlatform
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platform
)
if
err
!=
nil
{
return
nil
,
err
}
return
s
.
hydrateSelectedAccount
(
ctx
,
account
)
}
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
...
@@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
localExcluded
[
account
.
ID
]
=
struct
{}{}
// 排除此账号
localExcluded
[
account
.
ID
]
=
struct
{}{}
// 排除此账号
continue
// 重新选择
continue
// 重新选择
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
// 对于等待计划的情况,也需要先检查会话限制
// 对于等待计划的情况,也需要先检查会话限制
...
@@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
stickyAccountID
>
0
&&
stickyAccountID
==
account
.
ID
&&
s
.
concurrencyService
!=
nil
{
if
stickyAccountID
>
0
&&
stickyAccountID
==
account
.
ID
&&
s
.
concurrencyService
!=
nil
{
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
account
.
ID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
account
.
ID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
account
,
AccountID
:
account
.
ID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
account
.
Concurrency
,
AccountID
:
account
.
ID
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxConcurrency
:
account
.
Concurrency
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
})
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
}
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
account
,
AccountID
:
account
.
ID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
account
.
Concurrency
,
AccountID
:
account
.
ID
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
MaxConcurrency
:
account
.
Concurrency
,
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
})
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
},
},
nil
}
}
}
}
...
@@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
s
.
debugModelRoutingEnabled
()
{
if
s
.
debugModelRoutingEnabled
()
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
stickyAccount
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
stickyAccount
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
}
}
...
@@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
s
.
debugModelRoutingEnabled
()
{
if
s
.
debugModelRoutingEnabled
()
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
item
.
account
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
item
.
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
}
}
...
@@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
s
.
debugModelRoutingEnabled
()
{
if
s
.
debugModelRoutingEnabled
()
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
item
.
account
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
item
.
account
,
AccountID
:
item
.
account
.
ID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
item
.
account
.
Concurrency
,
AccountID
:
item
.
account
.
ID
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxConcurrency
:
item
.
account
.
Concurrency
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
})
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
}
...
@@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
}
else
{
return
&
AccountSelectionResult
{
if
s
.
cache
!=
nil
{
Account
:
account
,
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
Acquired
:
true
,
}
ReleaseFunc
:
result
.
ReleaseFunc
,
return
s
.
newSelectionResult
(
ctx
,
account
,
true
,
result
.
ReleaseFunc
,
nil
)
},
nil
}
}
}
}
...
@@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 会话限制已满,继续到 Layer 2
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
// Session limit full, continue to Layer 2
}
else
{
}
else
{
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
account
,
AccountID
:
accountID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
account
.
Concurrency
,
AccountID
:
accountID
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxConcurrency
:
account
.
Concurrency
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
})
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
}
}
}
}
}
...
@@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
if
err
!=
nil
{
if
err
!=
nil
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
);
ok
{
if
result
,
ok
,
legacyErr
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
);
legacyErr
!=
nil
{
return
nil
,
legacyErr
}
else
if
ok
{
return
result
,
nil
return
result
,
nil
}
}
}
else
{
}
else
{
...
@@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
selected
.
account
.
ID
,
stickySessionTTL
)
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
selected
.
account
.
ID
,
stickySessionTTL
)
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
selected
.
account
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
selected
.
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
}
}
...
@@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
sessionHash
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
sessionHash
)
{
continue
// 会话限制已满,尝试下一个账号
continue
// 会话限制已满,尝试下一个账号
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
acc
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
acc
,
AccountID
:
acc
.
ID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
acc
.
Concurrency
,
AccountID
:
acc
.
ID
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
MaxConcurrency
:
acc
.
Concurrency
,
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
})
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
},
},
nil
}
}
return
nil
,
ErrNoAvailableAccounts
return
nil
,
ErrNoAvailableAccounts
}
}
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
)
(
*
AccountSelectionResult
,
bool
)
{
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
)
(
*
AccountSelectionResult
,
bool
,
error
)
{
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
...
@@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
...
@@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
acc
.
ID
,
stickySessionTTL
)
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
acc
.
ID
,
stickySessionTTL
)
}
}
return
&
AccountSelectionResult
{
selection
,
err
:=
s
.
newSelectionResult
(
ctx
,
acc
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
acc
,
if
err
!=
nil
{
Acquired
:
true
,
return
nil
,
false
,
err
ReleaseFunc
:
result
.
ReleaseFunc
,
}
}
,
true
return
selection
,
true
,
nil
}
}
}
}
return
nil
,
false
return
nil
,
false
,
nil
}
}
func
(
s
*
GatewayService
)
schedulingConfig
()
config
.
GatewaySchedulingConfig
{
func
(
s
*
GatewayService
)
schedulingConfig
()
config
.
GatewaySchedulingConfig
{
...
@@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
...
@@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
return
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
}
}
func
(
s
*
GatewayService
)
hydrateSelectedAccount
(
ctx
context
.
Context
,
account
*
Account
)
(
*
Account
,
error
)
{
if
account
==
nil
||
s
.
schedulerSnapshot
==
nil
{
return
account
,
nil
}
hydrated
,
err
:=
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
account
.
ID
)
if
err
!=
nil
{
return
nil
,
err
}
if
hydrated
==
nil
{
return
nil
,
fmt
.
Errorf
(
"selected gateway account %d not found during hydration"
,
account
.
ID
)
}
return
hydrated
,
nil
}
func
(
s
*
GatewayService
)
newSelectionResult
(
ctx
context
.
Context
,
account
*
Account
,
acquired
bool
,
release
func
(),
waitPlan
*
AccountWaitPlan
)
(
*
AccountSelectionResult
,
error
)
{
hydrated
,
err
:=
s
.
hydrateSelectedAccount
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
return
&
AccountSelectionResult
{
Account
:
hydrated
,
Acquired
:
acquired
,
ReleaseFunc
:
release
,
WaitPlan
:
waitPlan
,
},
nil
}
// filterByMinPriority 过滤出优先级最小的账号集合
// filterByMinPriority 过滤出优先级最小的账号集合
func
filterByMinPriority
(
accounts
[]
accountWithLoad
)
[]
accountWithLoad
{
func
filterByMinPriority
(
accounts
[]
accountWithLoad
)
[]
accountWithLoad
{
if
len
(
accounts
)
==
0
{
if
len
(
accounts
)
==
0
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
a04ae28a
...
@@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
...
@@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
geminiStickySessionTTL
)
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
geminiStickySessionTTL
)
}
}
return
s
elected
,
nil
return
s
.
hydrateSelectedAccount
(
ctx
,
selected
)
}
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
...
@@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context,
...
@@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context,
return
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
return
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
}
}
func
(
s
*
GeminiMessagesCompatService
)
hydrateSelectedAccount
(
ctx
context
.
Context
,
account
*
Account
)
(
*
Account
,
error
)
{
if
account
==
nil
||
s
.
schedulerSnapshot
==
nil
{
return
account
,
nil
}
hydrated
,
err
:=
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
account
.
ID
)
if
err
!=
nil
{
return
nil
,
err
}
if
hydrated
==
nil
{
return
nil
,
fmt
.
Errorf
(
"selected gemini account %d not found during hydration"
,
account
.
ID
)
}
return
hydrated
,
nil
}
func
(
s
*
GeminiMessagesCompatService
)
listSchedulableAccountsOnce
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
,
hasForcePlatform
bool
)
([]
Account
,
error
)
{
func
(
s
*
GeminiMessagesCompatService
)
listSchedulableAccountsOnce
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
,
hasForcePlatform
bool
)
([]
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
if
s
.
schedulerSnapshot
!=
nil
{
accounts
,
_
,
err
:=
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
accounts
,
_
,
err
:=
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
...
@@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
...
@@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
if
selected
==
nil
{
if
selected
==
nil
{
return
nil
,
errors
.
New
(
"no available Gemini accounts"
)
return
nil
,
errors
.
New
(
"no available Gemini accounts"
)
}
}
return
s
elected
,
nil
return
s
.
hydrateSelectedAccount
(
ctx
,
selected
)
}
}
func
(
s
*
GeminiMessagesCompatService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
func
(
s
*
GeminiMessagesCompatService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
...
...
backend/internal/service/group.go
View file @
a04ae28a
...
@@ -3,8 +3,12 @@ package service
...
@@ -3,8 +3,12 @@ package service
import
(
import
(
"strings"
"strings"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
)
type
OpenAIMessagesDispatchModelConfig
=
domain
.
OpenAIMessagesDispatchModelConfig
type
Group
struct
{
type
Group
struct
{
ID
int64
ID
int64
Name
string
Name
string
...
@@ -49,10 +53,11 @@ type Group struct {
...
@@ -49,10 +53,11 @@ type Group struct {
SortOrder
int
SortOrder
int
// OpenAI Messages 调度配置(仅 openai 平台使用)
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch
bool
AllowMessagesDispatch
bool
RequireOAuthOnly
bool
// 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini)
RequireOAuthOnly
bool
// 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini)
RequirePrivacySet
bool
// 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini)
RequirePrivacySet
bool
// 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini)
DefaultMappedModel
string
DefaultMappedModel
string
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
CreatedAt
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
backend/internal/service/openai_codex_instructions_template.go
0 → 100644
View file @
a04ae28a
package
service
import
(
"bytes"
"fmt"
"strings"
"text/template"
)
type
forcedCodexInstructionsTemplateData
struct
{
ExistingInstructions
string
OriginalModel
string
NormalizedModel
string
BillingModel
string
UpstreamModel
string
}
func
applyForcedCodexInstructionsTemplate
(
reqBody
map
[
string
]
any
,
templateText
string
,
data
forcedCodexInstructionsTemplateData
,
)
(
bool
,
error
)
{
rendered
,
err
:=
renderForcedCodexInstructionsTemplate
(
templateText
,
data
)
if
err
!=
nil
{
return
false
,
err
}
if
rendered
==
""
{
return
false
,
nil
}
existing
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
if
strings
.
TrimSpace
(
existing
)
==
rendered
{
return
false
,
nil
}
reqBody
[
"instructions"
]
=
rendered
return
true
,
nil
}
func
renderForcedCodexInstructionsTemplate
(
templateText
string
,
data
forcedCodexInstructionsTemplateData
,
)
(
string
,
error
)
{
tmpl
,
err
:=
template
.
New
(
"forced_codex_instructions"
)
.
Option
(
"missingkey=zero"
)
.
Parse
(
templateText
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"parse forced codex instructions template: %w"
,
err
)
}
var
buf
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
buf
,
data
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"render forced codex instructions template: %w"
,
err
)
}
return
strings
.
TrimSpace
(
buf
.
String
()),
nil
}
backend/internal/service/openai_compat_model_test.go
View file @
a04ae28a
...
@@ -6,9 +6,12 @@ import (
...
@@ -6,9 +6,12 @@ import (
"io"
"io"
"net/http"
"net/http"
"net/http/httptest"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"strings"
"testing"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T
...
@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T
t
.
Logf
(
"upstream body: %s"
,
string
(
upstream
.
lastBody
))
t
.
Logf
(
"upstream body: %s"
,
string
(
upstream
.
lastBody
))
t
.
Logf
(
"response body: %s"
,
rec
.
Body
.
String
())
t
.
Logf
(
"response body: %s"
,
rec
.
Body
.
String
())
}
}
func
TestForwardAsAnthropic_ForcedCodexInstructionsTemplatePrependsRenderedInstructions
(
t
*
testing
.
T
)
{
t
.
Parallel
()
gin
.
SetMode
(
gin
.
TestMode
)
templateDir
:=
t
.
TempDir
()
templatePath
:=
filepath
.
Join
(
templateDir
,
"codex-instructions.md.tmpl"
)
require
.
NoError
(
t
,
os
.
WriteFile
(
templatePath
,
[]
byte
(
"server-prefix
\n\n
{{ .ExistingInstructions }}"
),
0
o644
))
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
[]
byte
(
`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
upstreamBody
:=
strings
.
Join
([]
string
{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
},
"x-request-id"
:
[]
string
{
"rid_forced"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
upstreamBody
)),
}}
svc
:=
&
OpenAIGatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
ForcedCodexInstructionsTemplateFile
:
templatePath
,
ForcedCodexInstructionsTemplate
:
"server-prefix
\n\n
{{ .ExistingInstructions }}"
,
}},
httpUpstream
:
upstream
,
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token"
,
"chatgpt_account_id"
:
"chatgpt-acc"
,
},
}
result
,
err
:=
svc
.
ForwardAsAnthropic
(
context
.
Background
(),
c
,
account
,
body
,
""
,
"gpt-5.1"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"server-prefix
\n\n
client-system"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"instructions"
)
.
String
())
}
func
TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateContent
(
t
*
testing
.
T
)
{
t
.
Parallel
()
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
[]
byte
(
`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
upstreamBody
:=
strings
.
Join
([]
string
{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
},
"x-request-id"
:
[]
string
{
"rid_forced_cached"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
upstreamBody
)),
}}
svc
:=
&
OpenAIGatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
ForcedCodexInstructionsTemplateFile
:
"/path/that/should/not/be/read.tmpl"
,
ForcedCodexInstructionsTemplate
:
"cached-prefix
\n\n
{{ .ExistingInstructions }}"
,
}},
httpUpstream
:
upstream
,
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token"
,
"chatgpt_account_id"
:
"chatgpt-acc"
,
},
}
result
,
err
:=
svc
.
ForwardAsAnthropic
(
context
.
Background
(),
c
,
account
,
body
,
""
,
"gpt-5.1"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"cached-prefix
\n\n
client-system"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"instructions"
)
.
String
())
}
backend/internal/service/openai_gateway_messages.go
View file @
a04ae28a
...
@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
...
@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return
nil
,
fmt
.
Errorf
(
"unmarshal for codex transform: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"unmarshal for codex transform: %w"
,
err
)
}
}
codexResult
:=
applyCodexOAuthTransform
(
reqBody
,
false
,
false
)
codexResult
:=
applyCodexOAuthTransform
(
reqBody
,
false
,
false
)
forcedTemplateText
:=
""
if
s
.
cfg
!=
nil
{
forcedTemplateText
=
s
.
cfg
.
Gateway
.
ForcedCodexInstructionsTemplate
}
templateUpstreamModel
:=
upstreamModel
if
codexResult
.
NormalizedModel
!=
""
{
templateUpstreamModel
=
codexResult
.
NormalizedModel
}
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
if
_
,
err
:=
applyForcedCodexInstructionsTemplate
(
reqBody
,
forcedTemplateText
,
forcedCodexInstructionsTemplateData
{
ExistingInstructions
:
strings
.
TrimSpace
(
existingInstructions
),
OriginalModel
:
originalModel
,
NormalizedModel
:
normalizedModel
,
BillingModel
:
billingModel
,
UpstreamModel
:
templateUpstreamModel
,
});
err
!=
nil
{
return
nil
,
err
}
if
codexResult
.
NormalizedModel
!=
""
{
if
codexResult
.
NormalizedModel
!=
""
{
upstreamModel
=
codexResult
.
NormalizedModel
upstreamModel
=
codexResult
.
NormalizedModel
}
}
...
...
backend/internal/service/openai_gateway_service.go
View file @
a04ae28a
...
@@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
...
@@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
_
=
s
.
setStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
,
selected
.
ID
,
openaiStickySessionTTL
)
_
=
s
.
setStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
,
selected
.
ID
,
openaiStickySessionTTL
)
}
}
return
s
elected
,
nil
return
s
.
hydrateSelectedAccount
(
ctx
,
selected
)
}
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// tryStickySessionHit 尝试从粘性会话获取账号。
...
@@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
account
.
ID
,
account
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
account
.
ID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
if
stickyAccountID
>
0
&&
stickyAccountID
==
account
.
ID
&&
s
.
concurrencyService
!=
nil
{
if
stickyAccountID
>
0
&&
stickyAccountID
==
account
.
ID
&&
s
.
concurrencyService
!=
nil
{
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
account
.
ID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
account
.
ID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
account
,
AccountID
:
account
.
ID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
account
.
Concurrency
,
AccountID
:
account
.
ID
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxConcurrency
:
account
.
Concurrency
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
})
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
}
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
account
,
AccountID
:
account
.
ID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
account
.
Concurrency
,
AccountID
:
account
.
ID
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
MaxConcurrency
:
account
.
Concurrency
,
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
})
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
},
},
nil
}
}
accounts
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
)
accounts
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
)
...
@@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
account
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
account
,
AccountID
:
accountID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
account
.
Concurrency
,
AccountID
:
accountID
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxConcurrency
:
account
.
Concurrency
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
})
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
}
}
}
}
}
...
@@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
_
=
s
.
setStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
,
fresh
.
ID
,
openaiStickySessionTTL
)
_
=
s
.
setStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
,
fresh
.
ID
,
openaiStickySessionTTL
)
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
fresh
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
fresh
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
}
}
}
else
{
}
else
{
...
@@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
_
=
s
.
setStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
,
fresh
.
ID
,
openaiStickySessionTTL
)
_
=
s
.
setStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
,
fresh
.
ID
,
openaiStickySessionTTL
)
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
fresh
,
true
,
result
.
ReleaseFunc
,
nil
)
Account
:
fresh
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
}
}
}
}
...
@@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
fresh
,
requestedModel
)
{
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
fresh
,
requestedModel
)
{
continue
continue
}
}
return
&
AccountSelectionResult
{
return
s
.
newSelectionResult
(
ctx
,
fresh
,
false
,
nil
,
&
AccountWaitPlan
{
Account
:
fresh
,
AccountID
:
fresh
.
ID
,
WaitPlan
:
&
AccountWaitPlan
{
MaxConcurrency
:
fresh
.
Concurrency
,
AccountID
:
fresh
.
ID
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
MaxConcurrency
:
fresh
.
Concurrency
,
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
Timeout
:
cfg
.
FallbackWaitTimeout
,
})
MaxWaiting
:
cfg
.
FallbackMaxWaiting
,
},
},
nil
}
}
return
nil
,
ErrNoAvailableAccounts
return
nil
,
ErrNoAvailableAccounts
...
@@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun
...
@@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun
return
account
,
nil
return
account
,
nil
}
}
func
(
s
*
OpenAIGatewayService
)
hydrateSelectedAccount
(
ctx
context
.
Context
,
account
*
Account
)
(
*
Account
,
error
)
{
if
account
==
nil
||
s
.
schedulerSnapshot
==
nil
{
return
account
,
nil
}
hydrated
,
err
:=
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
account
.
ID
)
if
err
!=
nil
{
return
nil
,
err
}
if
hydrated
==
nil
{
return
nil
,
fmt
.
Errorf
(
"selected openai account %d not found during hydration"
,
account
.
ID
)
}
return
hydrated
,
nil
}
func
(
s
*
OpenAIGatewayService
)
newSelectionResult
(
ctx
context
.
Context
,
account
*
Account
,
acquired
bool
,
release
func
(),
waitPlan
*
AccountWaitPlan
)
(
*
AccountSelectionResult
,
error
)
{
hydrated
,
err
:=
s
.
hydrateSelectedAccount
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
return
&
AccountSelectionResult
{
Account
:
hydrated
,
Acquired
:
acquired
,
ReleaseFunc
:
release
,
WaitPlan
:
waitPlan
,
},
nil
}
func
(
s
*
OpenAIGatewayService
)
schedulingConfig
()
config
.
GatewaySchedulingConfig
{
func
(
s
*
OpenAIGatewayService
)
schedulingConfig
()
config
.
GatewaySchedulingConfig
{
if
s
.
cfg
!=
nil
{
if
s
.
cfg
!=
nil
{
return
s
.
cfg
.
Gateway
.
Scheduling
return
s
.
cfg
.
Gateway
.
Scheduling
...
...
backend/internal/service/openai_messages_dispatch.go
0 → 100644
View file @
a04ae28a
package
service
import
"strings"
const
(
defaultOpenAIMessagesDispatchOpusMappedModel
=
"gpt-5.4"
defaultOpenAIMessagesDispatchSonnetMappedModel
=
"gpt-5.3-codex"
defaultOpenAIMessagesDispatchHaikuMappedModel
=
"gpt-5.4-mini"
)
func
normalizeOpenAIMessagesDispatchMappedModel
(
model
string
)
string
{
model
=
NormalizeOpenAICompatRequestedModel
(
strings
.
TrimSpace
(
model
))
return
strings
.
TrimSpace
(
model
)
}
func
normalizeOpenAIMessagesDispatchModelConfig
(
cfg
OpenAIMessagesDispatchModelConfig
)
OpenAIMessagesDispatchModelConfig
{
out
:=
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
normalizeOpenAIMessagesDispatchMappedModel
(
cfg
.
OpusMappedModel
),
SonnetMappedModel
:
normalizeOpenAIMessagesDispatchMappedModel
(
cfg
.
SonnetMappedModel
),
HaikuMappedModel
:
normalizeOpenAIMessagesDispatchMappedModel
(
cfg
.
HaikuMappedModel
),
}
if
len
(
cfg
.
ExactModelMappings
)
>
0
{
out
.
ExactModelMappings
=
make
(
map
[
string
]
string
,
len
(
cfg
.
ExactModelMappings
))
for
requestedModel
,
mappedModel
:=
range
cfg
.
ExactModelMappings
{
requestedModel
=
strings
.
TrimSpace
(
requestedModel
)
mappedModel
=
normalizeOpenAIMessagesDispatchMappedModel
(
mappedModel
)
if
requestedModel
==
""
||
mappedModel
==
""
{
continue
}
out
.
ExactModelMappings
[
requestedModel
]
=
mappedModel
}
if
len
(
out
.
ExactModelMappings
)
==
0
{
out
.
ExactModelMappings
=
nil
}
}
return
out
}
func
claudeMessagesDispatchFamily
(
model
string
)
string
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
if
!
strings
.
HasPrefix
(
normalized
,
"claude"
)
{
return
""
}
switch
{
case
strings
.
Contains
(
normalized
,
"opus"
)
:
return
"opus"
case
strings
.
Contains
(
normalized
,
"sonnet"
)
:
return
"sonnet"
case
strings
.
Contains
(
normalized
,
"haiku"
)
:
return
"haiku"
default
:
return
""
}
}
func
(
g
*
Group
)
ResolveMessagesDispatchModel
(
requestedModel
string
)
string
{
if
g
==
nil
{
return
""
}
requestedModel
=
strings
.
TrimSpace
(
requestedModel
)
if
requestedModel
==
""
{
return
""
}
cfg
:=
normalizeOpenAIMessagesDispatchModelConfig
(
g
.
MessagesDispatchModelConfig
)
if
mappedModel
:=
strings
.
TrimSpace
(
cfg
.
ExactModelMappings
[
requestedModel
]);
mappedModel
!=
""
{
return
mappedModel
}
switch
claudeMessagesDispatchFamily
(
requestedModel
)
{
case
"opus"
:
if
mappedModel
:=
strings
.
TrimSpace
(
cfg
.
OpusMappedModel
);
mappedModel
!=
""
{
return
mappedModel
}
return
defaultOpenAIMessagesDispatchOpusMappedModel
case
"sonnet"
:
if
mappedModel
:=
strings
.
TrimSpace
(
cfg
.
SonnetMappedModel
);
mappedModel
!=
""
{
return
mappedModel
}
return
defaultOpenAIMessagesDispatchSonnetMappedModel
case
"haiku"
:
if
mappedModel
:=
strings
.
TrimSpace
(
cfg
.
HaikuMappedModel
);
mappedModel
!=
""
{
return
mappedModel
}
return
defaultOpenAIMessagesDispatchHaikuMappedModel
default
:
return
""
}
}
func
sanitizeGroupMessagesDispatchFields
(
g
*
Group
)
{
if
g
==
nil
||
g
.
Platform
==
PlatformOpenAI
{
return
}
g
.
AllowMessagesDispatch
=
false
g
.
DefaultMappedModel
=
""
g
.
MessagesDispatchModelConfig
=
OpenAIMessagesDispatchModelConfig
{}
}
Prev
1
…
5
6
7
8
9
10
11
12
13
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment