Commit a04ae28a authored by 陈曦's avatar 陈曦
Browse files

merge v0.1.111

parents 68f67198 ad64190b
...@@ -111,4 +111,5 @@ func registerRoutes( ...@@ -111,4 +111,5 @@ func registerRoutes(
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
} }
...@@ -70,6 +70,14 @@ func RegisterAuthRoutes( ...@@ -70,6 +70,14 @@ func RegisterAuthRoutes(
}), }),
h.Auth.CompleteLinuxDoOAuthRegistration, h.Auth.CompleteLinuxDoOAuthRegistration,
) )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
} }
// 公开设置(无需认证) // 公开设置(无需认证)
......
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterPaymentRoutes registers all payment-related routes:
// user-facing endpoints, webhook endpoints, and admin endpoints.
func RegisterPaymentRoutes(
v1 *gin.RouterGroup,
paymentHandler *handler.PaymentHandler,
webhookHandler *handler.PaymentWebhookHandler,
adminPaymentHandler *admin.PaymentHandler,
jwtAuth middleware.JWTAuthMiddleware,
adminAuth middleware.AdminAuthMiddleware,
settingService *service.SettingService,
) {
// --- User-facing payment endpoints (authenticated) ---
authenticated := v1.Group("/payment")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.GET("/config", paymentHandler.GetPaymentConfig)
authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo)
authenticated.GET("/plans", paymentHandler.GetPlans)
authenticated.GET("/channels", paymentHandler.GetChannels)
authenticated.GET("/limits", paymentHandler.GetLimits)
orders := authenticated.Group("/orders")
{
orders.POST("", paymentHandler.CreateOrder)
orders.POST("/verify", paymentHandler.VerifyOrder)
orders.GET("/my", paymentHandler.GetMyOrders)
orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
}
}
// --- Public payment endpoints (no auth) ---
// Payment result page needs to verify order status without login
// (user session may have expired during provider redirect).
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
}
// --- Webhook endpoints (no auth) ---
webhook := v1.Group("/payment/webhook")
{
// EasyPay sends GET callbacks with query params
webhook.GET("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/alipay", webhookHandler.AlipayNotify)
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
webhook.POST("/stripe", webhookHandler.StripeWebhook)
}
// --- Admin payment endpoints (admin auth) ---
adminGroup := v1.Group("/admin/payment")
adminGroup.Use(gin.HandlerFunc(adminAuth))
{
// Dashboard
adminGroup.GET("/dashboard", adminPaymentHandler.GetDashboard)
// Config
adminGroup.GET("/config", adminPaymentHandler.GetConfig)
adminGroup.PUT("/config", adminPaymentHandler.UpdateConfig)
// Orders
adminOrders := adminGroup.Group("/orders")
{
adminOrders.GET("", adminPaymentHandler.ListOrders)
adminOrders.GET("/:id", adminPaymentHandler.GetOrderDetail)
adminOrders.POST("/:id/cancel", adminPaymentHandler.CancelOrder)
adminOrders.POST("/:id/retry", adminPaymentHandler.RetryFulfillment)
adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund)
}
// Subscription Plans
plans := adminGroup.Group("/plans")
{
plans.GET("", adminPaymentHandler.ListPlans)
plans.POST("", adminPaymentHandler.CreatePlan)
plans.PUT("/:id", adminPaymentHandler.UpdatePlan)
plans.DELETE("/:id", adminPaymentHandler.DeletePlan)
}
// Provider Instances
providers := adminGroup.Group("/providers")
{
providers.GET("", adminPaymentHandler.ListProviders)
providers.POST("", adminPaymentHandler.CreateProvider)
providers.PUT("/:id", adminPaymentHandler.UpdateProvider)
providers.DELETE("/:id", adminPaymentHandler.DeleteProvider)
}
}
}
...@@ -21,13 +21,13 @@ import ( ...@@ -21,13 +21,13 @@ import (
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types. // codeType is optional - pass empty string to return all types.
...@@ -35,7 +35,7 @@ type AdminService interface { ...@@ -35,7 +35,7 @@ type AdminService interface {
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
...@@ -55,7 +55,7 @@ type AdminService interface { ...@@ -55,7 +55,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
...@@ -77,8 +77,8 @@ type AdminService interface { ...@@ -77,8 +77,8 @@ type AdminService interface {
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error) GetProxy(ctx context.Context, id int64) (*Proxy, error)
...@@ -93,7 +93,7 @@ type AdminService interface { ...@@ -93,7 +93,7 @@ type AdminService interface {
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management // Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error DeleteRedeemCode(ctx context.Context, id int64) error
...@@ -152,10 +152,11 @@ type CreateGroupInput struct { ...@@ -152,10 +152,11 @@ type CreateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string SupportedModelScopes []string
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool AllowMessagesDispatch bool
DefaultMappedModel string DefaultMappedModel string
RequireOAuthOnly bool RequireOAuthOnly bool
RequirePrivacySet bool RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(创建分组后在同一事务内绑定) // 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -186,10 +187,11 @@ type UpdateGroupInput struct { ...@@ -186,10 +187,11 @@ type UpdateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string SupportedModelScopes *[]string
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool AllowMessagesDispatch *bool
DefaultMappedModel *string DefaultMappedModel *string
RequireOAuthOnly *bool RequireOAuthOnly *bool
RequirePrivacySet *bool RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -483,8 +485,8 @@ func NewAdminService( ...@@ -483,8 +485,8 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -751,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -751,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -787,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int ...@@ -787,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequireOAuthOnly: input.RequireOAuthOnly, RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet, RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel, DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
} }
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
} }
...@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.DefaultMappedModel != nil { if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel group.DefaultMappedModel = *input.DefaultMappedModel
} }
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
...@@ -1456,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou ...@@ -1456,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1885,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, ...@@ -1885,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
} }
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1894,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, ...@@ -1894,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil return proxies, result.Total, nil
} }
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -2032,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po ...@@ -2032,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
} }
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
......
...@@ -10,6 +10,11 @@ import ( ...@@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func ptrString[T ~string](v T) *string {
s := string(v)
return &s
}
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub // groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct { type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数 created *Group // 记录 Create 调用的参数
...@@ -120,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor ...@@ -120,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
return nil return nil
} }
func TestAdminService_ListGroups_PassesSortParams(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "g1"}},
}
svc := &adminServiceImpl{groupRepo: repo}
_, _, err := svc.ListGroups(context.Background(), 3, 25, PlatformOpenAI, StatusActive, "needle", nil, "account_count", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 3,
PageSize: 25,
SortBy: "account_count",
SortOrder: "ASC",
}, repo.listWithFiltersParams)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
...@@ -245,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -245,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "dispatch-group",
Description: "dispatch config",
Platform: PlatformOpenAI,
RateMultiplier: 1.0,
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: " gpt-5.3-codex ",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
},
}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: " gpt-5.4-medium ",
ExactModelMappings: map[string]string{
" claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.4",
ExactModelMappings: map[string]string{
"claude-haiku-4-5-20251001": "gpt-5.4-mini",
},
}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "anthropic-group",
Description: "non-openai",
Platform: PlatformAnthropic,
RateMultiplier: 1.0,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.False(t, repo.created.AllowMessagesDispatch)
require.Empty(t, repo.created.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-openai-group",
Platform: PlatformOpenAI,
Status: StatusActive,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.3-codex",
},
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Platform: PlatformAnthropic,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, PlatformAnthropic, repo.updated.Platform)
require.False(t, repo.updated.AllowMessagesDispatch)
require.Empty(t, repo.updated.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_ListGroups_WithSearch(t *testing.T) { func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试: // 测试:
// 1. search 参数正常传递到 repository 层 // 1. search 参数正常传递到 repository 层
...@@ -258,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -258,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
...@@ -276,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -276,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, groups) require.Empty(t, groups)
require.Equal(t, int64(0), total) require.Equal(t, int64(0), total)
...@@ -295,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -295,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(42), total) require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
......
...@@ -13,11 +13,13 @@ import ( ...@@ -13,11 +13,13 @@ import (
type userRepoStubForListUsers struct { type userRepoStubForListUsers struct {
userRepoStub userRepoStub
users []User users []User
err error err error
listWithFiltersParams pagination.PaginationParams
} }
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
s.listWithFiltersParams = params
if s.err != nil { if s.err != nil {
return nil, nil, s.err return nil, nil, s.err
} }
...@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { ...@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userGroupRateRepo: rateRepo, userGroupRateRepo: rateRepo,
} }
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(2), total) require.Equal(t, int64(2), total)
require.Len(t, users, 2) require.Len(t, users, 2)
...@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { ...@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
require.Equal(t, 1.1, users[0].GroupRates[11]) require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22]) require.Equal(t, 2.2, users[1].GroupRates[22])
} }
func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{{ID: 1, Email: "a@example.com"}},
}
svc := &adminServiceImpl{userRepo: userRepo}
_, _, err := svc.ListUsers(context.Background(), 2, 50, UserListFilters{}, "email", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 2,
PageSize: 50,
SortBy: "email",
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
...@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { ...@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
...@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { ...@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
...@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) { ...@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(7), total) require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol) require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch) require.Equal(t, "p1", repo.listWithFiltersSearch)
...@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) { ...@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2", "account_count", "DESC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(9), total) require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10, SortBy: "account_count", SortOrder: "DESC"}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
...@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) { ...@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{redeemCodeRepo: repo} svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC", "value", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(3), total) require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "value", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus) require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch) require.Equal(t, "ABC", repo.listWithFiltersSearch)
......
...@@ -4,6 +4,7 @@ import "time" ...@@ -4,6 +4,7 @@ import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct { type APIKeyAuthSnapshot struct {
Version int `json:"version"`
APIKeyID int64 `json:"api_key_id"` APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
...@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct {
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
...@@ -13,6 +13,8 @@ import ( ...@@ -13,6 +13,8 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 3
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
l1TTL time.Duration l1TTL time.Duration
...@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn ...@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
if entry.Snapshot == nil { if entry.Snapshot == nil {
return nil, false, nil return nil, false, nil
} }
if entry.Snapshot.Version != apiKeyAuthSnapshotVersion {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
} }
...@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
return nil return nil
} }
snapshot := &APIKeyAuthSnapshot{ snapshot := &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
UserID: apiKey.UserID, UserID: apiKey.UserID,
GroupID: apiKey.GroupID, GroupID: apiKey.GroupID,
...@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
SupportedModelScopes: apiKey.Group.SupportedModelScopes, SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel, DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
} }
} }
return snapshot return snapshot
...@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes, SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel, DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
} }
} }
s.compileAPIKeyIPRules(apiKey) s.compileAPIKeyIPRules(apiKey)
......
...@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
groupID := int64(9) groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{ cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{ Snapshot: &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: 1, APIKeyID: 1,
UserID: 2, UserID: 2,
GroupID: &groupID, GroupID: &groupID,
...@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting) require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
} }
func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t *testing.T) {
svc := NewAPIKeyService(nil, nil, nil, nil, nil, nil, &config.Config{})
groupID := int64(9)
apiKey := &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Key: "k-roundtrip",
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
},
}
snapshot := svc.snapshotFromAPIKey(apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
require.NotNil(t, roundTrip.Group)
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
}
func TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig(t *testing.T) {
cache := &authCacheStub{}
var repoCalls int32
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&repoCalls, 1)
groupID := int64(9)
return &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
Hydrated: true,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
},
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
},
},
}, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k-legacy")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&repoCalls))
require.NotNil(t, apiKey.Group)
require.Equal(t, "gpt-5.4-nano", apiKey.Group.MessagesDispatchModelConfig.OpusMappedModel)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{} cache := &authCacheStub{}
repo := &authRepoStub{ repo := &authRepoStub{
......
...@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) { ...@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool { func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email)) normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
} }
// GenerateToken 生成JWT access token // GenerateToken 生成JWT access token
......
...@@ -71,6 +71,9 @@ const ( ...@@ -71,6 +71,9 @@ const (
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。 // LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// Setting keys // Setting keys
const ( const (
// 注册设置 // 注册设置
...@@ -105,6 +108,30 @@ const ( ...@@ -105,6 +108,30 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
SettingKeyOIDCConnectClientID = "oidc_connect_client_id"
SettingKeyOIDCConnectClientSecret = "oidc_connect_client_secret"
SettingKeyOIDCConnectIssuerURL = "oidc_connect_issuer_url"
SettingKeyOIDCConnectDiscoveryURL = "oidc_connect_discovery_url"
SettingKeyOIDCConnectAuthorizeURL = "oidc_connect_authorize_url"
SettingKeyOIDCConnectTokenURL = "oidc_connect_token_url"
SettingKeyOIDCConnectUserInfoURL = "oidc_connect_userinfo_url"
SettingKeyOIDCConnectJWKSURL = "oidc_connect_jwks_url"
SettingKeyOIDCConnectScopes = "oidc_connect_scopes"
SettingKeyOIDCConnectRedirectURL = "oidc_connect_redirect_url"
SettingKeyOIDCConnectFrontendRedirectURL = "oidc_connect_frontend_redirect_url"
SettingKeyOIDCConnectTokenAuthMethod = "oidc_connect_token_auth_method"
SettingKeyOIDCConnectUsePKCE = "oidc_connect_use_pkce"
SettingKeyOIDCConnectValidateIDToken = "oidc_connect_validate_id_token"
SettingKeyOIDCConnectAllowedSigningAlgs = "oidc_connect_allowed_signing_algs"
SettingKeyOIDCConnectClockSkewSeconds = "oidc_connect_clock_skew_seconds"
SettingKeyOIDCConnectRequireEmailVerified = "oidc_connect_require_email_verified"
SettingKeyOIDCConnectUserInfoEmailPath = "oidc_connect_userinfo_email_path"
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
// OEM设置 // OEM设置
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
...@@ -116,6 +143,8 @@ const ( ...@@ -116,6 +143,8 @@ const (
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src)
SettingKeyTableDefaultPageSize = "table_default_page_size" // 表格默认每页条数
SettingKeyTablePageSizeOptions = "table_page_size_options" // 表格可选每页条数(JSON 数组)
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组)
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
......
...@@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度 // 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
} }
// antigravity 分组、强制平台模式或无分组使用单平台选择 // antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
...@@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
localExcluded[account.ID] = struct{}{} // 排除此账号 localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择 continue // 重新选择
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
// 对于等待计划的情况,也需要先检查会话限制 // 对于等待计划的情况,也需要先检查会话限制
...@@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
} }
...@@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
Account: item.account, AccountID: item.account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: item.account.Concurrency,
AccountID: item.account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: item.account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
// 所有路由账号会话限制都已满,继续到 Layer 2 回退 // 所有路由账号会话限制都已满,继续到 Layer 2 回退
} }
...@@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
return &AccountSelectionResult{ if s.cache != nil {
Account: account, _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
Acquired: true, }
ReleaseFunc: result.ReleaseFunc, return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
}, nil
} }
} }
...@@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 会话限制已满,继续到 Layer 2 // 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2 // Session limit full, continue to Layer 2
} else { } else {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: accountID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: accountID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
} }
...@@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
return nil, legacyErr
} else if ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, acc, sessionHash) { if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号 continue // 会话限制已满,尝试下一个账号
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
Account: acc, AccountID: acc.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: acc.Concurrency,
AccountID: acc.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: acc.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
return nil, ErrNoAvailableAccounts return nil, ErrNoAvailableAccounts
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
...@@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates ...@@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
Account: acc, if err != nil {
Acquired: true, return nil, false, err
ReleaseFunc: result.ReleaseFunc, }
}, true return selection, true, nil
} }
} }
return nil, false return nil, false, nil
} }
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
...@@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in ...@@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID) return s.accountRepo.GetByID(ctx, accountID)
} }
func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
if account == nil || s.schedulerSnapshot == nil {
return account, nil
}
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
if err != nil {
return nil, err
}
if hydrated == nil {
return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
}
return hydrated, nil
}
func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
hydrated, err := s.hydrateSelectedAccount(ctx, account)
if err != nil {
return nil, err
}
return &AccountSelectionResult{
Account: hydrated,
Acquired: acquired,
ReleaseFunc: release,
WaitPlan: waitPlan,
}, nil
}
// filterByMinPriority 过滤出优先级最小的账号集合 // filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 { if len(accounts) == 0 {
......
...@@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
} }
return selected, nil return s.hydrateSelectedAccount(ctx, selected)
} }
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。 // resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
...@@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, ...@@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context,
return s.accountRepo.GetByID(ctx, accountID) return s.accountRepo.GetByID(ctx, accountID)
} }
func (s *GeminiMessagesCompatService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
if account == nil || s.schedulerSnapshot == nil {
return account, nil
}
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
if err != nil {
return nil, err
}
if hydrated == nil {
return nil, fmt.Errorf("selected gemini account %d not found during hydration", account.ID)
}
return hydrated, nil
}
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) { func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
...@@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont ...@@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
if selected == nil { if selected == nil {
return nil, errors.New("no available Gemini accounts") return nil, errors.New("no available Gemini accounts")
} }
return selected, nil return s.hydrateSelectedAccount(ctx, selected)
} }
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
......
...@@ -3,8 +3,12 @@ package service ...@@ -3,8 +3,12 @@ package service
import ( import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig
type Group struct { type Group struct {
ID int64 ID int64
Name string Name string
...@@ -49,10 +53,11 @@ type Group struct { ...@@ -49,10 +53,11 @@ type Group struct {
SortOrder int SortOrder int
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool AllowMessagesDispatch bool
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini)
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini)
DefaultMappedModel string DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
......
package service
import (
"bytes"
"fmt"
"strings"
"text/template"
)
type forcedCodexInstructionsTemplateData struct {
ExistingInstructions string
OriginalModel string
NormalizedModel string
BillingModel string
UpstreamModel string
}
func applyForcedCodexInstructionsTemplate(
reqBody map[string]any,
templateText string,
data forcedCodexInstructionsTemplateData,
) (bool, error) {
rendered, err := renderForcedCodexInstructionsTemplate(templateText, data)
if err != nil {
return false, err
}
if rendered == "" {
return false, nil
}
existing, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existing) == rendered {
return false, nil
}
reqBody["instructions"] = rendered
return true, nil
}
func renderForcedCodexInstructionsTemplate(
templateText string,
data forcedCodexInstructionsTemplateData,
) (string, error) {
tmpl, err := template.New("forced_codex_instructions").Option("missingkey=zero").Parse(templateText)
if err != nil {
return "", fmt.Errorf("parse forced codex instructions template: %w", err)
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return "", fmt.Errorf("render forced codex instructions template: %w", err)
}
return strings.TrimSpace(buf.String()), nil
}
...@@ -6,9 +6,12 @@ import ( ...@@ -6,9 +6,12 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T ...@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T
t.Logf("upstream body: %s", string(upstream.lastBody)) t.Logf("upstream body: %s", string(upstream.lastBody))
t.Logf("response body: %s", rec.Body.String()) t.Logf("response body: %s", rec.Body.String())
} }
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplatePrependsRenderedInstructions(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
templateDir := t.TempDir()
templatePath := filepath.Join(templateDir, "codex-instructions.md.tmpl")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: templatePath,
ForcedCodexInstructionsTemplate: "server-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "server-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateContent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced_cached"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: "/path/that/should/not/be/read.tmpl",
ForcedCodexInstructionsTemplate: "cached-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
...@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err) return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
} }
codexResult := applyCodexOAuthTransform(reqBody, false, false) codexResult := applyCodexOAuthTransform(reqBody, false, false)
forcedTemplateText := ""
if s.cfg != nil {
forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate
}
templateUpstreamModel := upstreamModel
if codexResult.NormalizedModel != "" {
templateUpstreamModel = codexResult.NormalizedModel
}
existingInstructions, _ := reqBody["instructions"].(string)
if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{
ExistingInstructions: strings.TrimSpace(existingInstructions),
OriginalModel: originalModel,
NormalizedModel: normalizedModel,
BillingModel: billingModel,
UpstreamModel: templateUpstreamModel,
}); err != nil {
return nil, err
}
if codexResult.NormalizedModel != "" { if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel upstreamModel = codexResult.NormalizedModel
} }
......
...@@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C ...@@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
} }
return selected, nil return s.hydrateSelectedAccount(ctx, selected)
} }
// tryStickySessionHit 尝试从粘性会话获取账号。 // tryStickySessionHit 尝试从粘性会话获取账号。
...@@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
} }
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
accounts, err := s.listSchedulableAccounts(ctx, groupID) accounts, err := s.listSchedulableAccounts(ctx, groupID)
...@@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: accountID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: accountID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
} }
...@@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if sessionHash != "" { if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
} else { } else {
...@@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if sessionHash != "" { if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
} }
...@@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue continue
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{
Account: fresh, AccountID: fresh.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: fresh.Concurrency,
AccountID: fresh.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: fresh.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
return nil, ErrNoAvailableAccounts return nil, ErrNoAvailableAccounts
...@@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun ...@@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun
return account, nil return account, nil
} }
func (s *OpenAIGatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
if account == nil || s.schedulerSnapshot == nil {
return account, nil
}
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
if err != nil {
return nil, err
}
if hydrated == nil {
return nil, fmt.Errorf("selected openai account %d not found during hydration", account.ID)
}
return hydrated, nil
}
func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
hydrated, err := s.hydrateSelectedAccount(ctx, account)
if err != nil {
return nil, err
}
return &AccountSelectionResult{
Account: hydrated,
Acquired: acquired,
ReleaseFunc: release,
WaitPlan: waitPlan,
}, nil
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil { if s.cfg != nil {
return s.cfg.Gateway.Scheduling return s.cfg.Gateway.Scheduling
......
package service
import "strings"
const (
defaultOpenAIMessagesDispatchOpusMappedModel = "gpt-5.4"
defaultOpenAIMessagesDispatchSonnetMappedModel = "gpt-5.3-codex"
defaultOpenAIMessagesDispatchHaikuMappedModel = "gpt-5.4-mini"
)
func normalizeOpenAIMessagesDispatchMappedModel(model string) string {
model = NormalizeOpenAICompatRequestedModel(strings.TrimSpace(model))
return strings.TrimSpace(model)
}
func normalizeOpenAIMessagesDispatchModelConfig(cfg OpenAIMessagesDispatchModelConfig) OpenAIMessagesDispatchModelConfig {
out := OpenAIMessagesDispatchModelConfig{
OpusMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.OpusMappedModel),
SonnetMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.SonnetMappedModel),
HaikuMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.HaikuMappedModel),
}
if len(cfg.ExactModelMappings) > 0 {
out.ExactModelMappings = make(map[string]string, len(cfg.ExactModelMappings))
for requestedModel, mappedModel := range cfg.ExactModelMappings {
requestedModel = strings.TrimSpace(requestedModel)
mappedModel = normalizeOpenAIMessagesDispatchMappedModel(mappedModel)
if requestedModel == "" || mappedModel == "" {
continue
}
out.ExactModelMappings[requestedModel] = mappedModel
}
if len(out.ExactModelMappings) == 0 {
out.ExactModelMappings = nil
}
}
return out
}
func claudeMessagesDispatchFamily(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
if !strings.HasPrefix(normalized, "claude") {
return ""
}
switch {
case strings.Contains(normalized, "opus"):
return "opus"
case strings.Contains(normalized, "sonnet"):
return "sonnet"
case strings.Contains(normalized, "haiku"):
return "haiku"
default:
return ""
}
}
func (g *Group) ResolveMessagesDispatchModel(requestedModel string) string {
if g == nil {
return ""
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
}
cfg := normalizeOpenAIMessagesDispatchModelConfig(g.MessagesDispatchModelConfig)
if mappedModel := strings.TrimSpace(cfg.ExactModelMappings[requestedModel]); mappedModel != "" {
return mappedModel
}
switch claudeMessagesDispatchFamily(requestedModel) {
case "opus":
if mappedModel := strings.TrimSpace(cfg.OpusMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchOpusMappedModel
case "sonnet":
if mappedModel := strings.TrimSpace(cfg.SonnetMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchSonnetMappedModel
case "haiku":
if mappedModel := strings.TrimSpace(cfg.HaikuMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchHaikuMappedModel
default:
return ""
}
}
func sanitizeGroupMessagesDispatchFields(g *Group) {
if g == nil || g.Platform == PlatformOpenAI {
return
}
g.AllowMessagesDispatch = false
g.DefaultMappedModel = ""
g.MessagesDispatchModelConfig = OpenAIMessagesDispatchModelConfig{}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment