Commit 654cfb64 authored by erio's avatar erio
Browse files

feat(channels): add "Available Channels" aggregate view

Add a read-only aggregate view per channel: its linked groups and a
deterministic wildcard-free supported-model list with pricing details.

Backend
- service.Channel.SupportedModels(): combine ModelMapping keys with
  same-platform ModelPricing.Models; trailing "*" keys expand via
  pricing prefix match; platforms without a mapping produce no
  entries (intentional "no mapping = not shown" rule).
- Extract splitWildcardSuffix() shared with toModelEntry.
- Build a per-call pricing lookup map (platform+lowerName -> *pricing)
  to avoid O(N*M) scans in SupportedModels.
- ChannelService.ListAvailable() aggregates channels + active groups;
  filters out group IDs no longer active.
- Admin route GET /api/v1/admin/channels/available returns the full
  DTO (id, status, billing_model_source, restrict_models, groups,
  supported_models).
- User route GET /api/v1/channels/available applies three filters:
  Status==active, visible-group intersection, and platform filter
  on supported_models (prevents cross-platform leak when a channel
  links to both a user-accessible group and an inaccessible one on
  another platform). Response is a plain array (matches the
  /groups/available sibling shape). Field whitelist omits
  billing_model_source, restrict_models, ids, status, sort_order.

Frontend
- New /admin/available-channels and /available-channels views backed
  by a shared AvailableChannelsTable component (admin adds status +
  billing-source columns via slots).
- PricingRow extracted to its own SFC; SupportedModelChip references
  shared billing-mode constants in constants/channel.ts.
- Sidebar: new entry above "渠道管理" for admin; matching entry in
  user nav.
- i18n: zh + en coverage for both namespaces.

Tests
- SupportedModels: wildcard-only pricing skipped, prefix-matches-
  nothing, cross-platform bleed, case-insensitive dedup, empty
  platform mapping.
- ListAvailable: nil groupRepo, inactive-group-ID dropped, stable
  case-insensitive name sort.
- User handler: 401 on unauthenticated, visible-group intersection,
  platform filter on supported_models, JSON whitelist.
- Admin handler: full DTO including default BillingModelSource
  fallback.

Refs: issue #1729
parent c46744f3
...@@ -174,7 +174,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -174,7 +174,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore() digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db) channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator)
availableChannelHandler := admin.NewAvailableChannelHandler(channelService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
...@@ -234,7 +235,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -234,7 +235,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler) availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, availableChannelHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
...@@ -246,7 +248,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -246,7 +248,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
......
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AvailableChannelHandler 处理「可用渠道」聚合视图的管理员接口。
//
// 该视图以只读方式聚合渠道基础信息、关联分组与推导出的支持模型列表(无通配符)。
type AvailableChannelHandler struct {
channelService *service.ChannelService
}
// NewAvailableChannelHandler 创建 AvailableChannelHandler 实例。
func NewAvailableChannelHandler(channelService *service.ChannelService) *AvailableChannelHandler {
return &AvailableChannelHandler{channelService: channelService}
}
// availableGroupResponse 响应中的分组概要。
type availableGroupResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
}
// supportedModelResponse 响应中的支持模型条目。
type supportedModelResponse struct {
Name string `json:"name"`
Platform string `json:"platform"`
Pricing *channelModelPricingResponse `json:"pricing"`
}
// availableChannelResponse 管理员视图完整字段集。
type availableChannelResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
Groups []availableGroupResponse `json:"groups"`
SupportedModels []supportedModelResponse `json:"supported_models"`
}
// AvailableChannelToAdminResponse 将 service 层的 AvailableChannel 转为管理员 DTO。
// 导出供同 package 的复用;也用于构造测试 fixture。
func AvailableChannelToAdminResponse(ch service.AvailableChannel) availableChannelResponse {
groups := make([]availableGroupResponse, 0, len(ch.Groups))
for _, g := range ch.Groups {
groups = append(groups, availableGroupResponse{ID: g.ID, Name: g.Name, Platform: g.Platform})
}
models := make([]supportedModelResponse, 0, len(ch.SupportedModels))
for i := range ch.SupportedModels {
m := ch.SupportedModels[i]
var pricing *channelModelPricingResponse
if m.Pricing != nil {
p := pricingToResponse(m.Pricing)
pricing = &p
}
models = append(models, supportedModelResponse{
Name: m.Name,
Platform: m.Platform,
Pricing: pricing,
})
}
billingSource := ch.BillingModelSource
if billingSource == "" {
billingSource = service.BillingModelSourceChannelMapped
}
return availableChannelResponse{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
BillingModelSource: billingSource,
RestrictModels: ch.RestrictModels,
Groups: groups,
SupportedModels: models,
}
}
// List 列出所有可用渠道(管理员视图)。
// GET /api/v1/admin/channels/available
func (h *AvailableChannelHandler) List(c *gin.Context) {
channels, err := h.channelService.ListAvailable(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]availableChannelResponse, 0, len(channels))
for _, ch := range channels {
out = append(out, AvailableChannelToAdminResponse(ch))
}
response.Success(c, gin.H{"items": out})
}
//go:build unit
package admin
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestAvailableChannelToAdminResponse_IncludesFullDTO(t *testing.T) {
// 管理员视图应包含 id / status / billing_model_source / restrict_models 等
// 管理字段;BillingModelSource 为空时应默认回填 channel_mapped。
input := service.AvailableChannel{
ID: 42,
Name: "ch",
Description: "d",
Status: service.StatusActive,
BillingModelSource: "", // 验证默认值填充
RestrictModels: true,
Groups: []service.AvailableGroupRef{
{ID: 1, Name: "g1", Platform: "anthropic"},
},
SupportedModels: []service.SupportedModel{
{Name: "claude-sonnet-4-6", Platform: "anthropic"},
},
}
resp := AvailableChannelToAdminResponse(input)
require.Equal(t, int64(42), resp.ID)
require.Equal(t, "ch", resp.Name)
require.Equal(t, service.StatusActive, resp.Status)
require.Equal(t, service.BillingModelSourceChannelMapped, resp.BillingModelSource)
require.True(t, resp.RestrictModels)
require.Len(t, resp.Groups, 1)
require.Len(t, resp.SupportedModels, 1)
// JSON 层验证管理字段确实会被序列化。
raw, err := json.Marshal(resp)
require.NoError(t, err)
var decoded map[string]any
require.NoError(t, json.Unmarshal(raw, &decoded))
for _, key := range []string{"id", "status", "billing_model_source", "restrict_models", "groups", "supported_models"} {
_, exists := decoded[key]
require.Truef(t, exists, "admin DTO must expose %q", key)
}
}
func TestAvailableChannelToAdminResponse_PreservesExplicitBillingSource(t *testing.T) {
input := service.AvailableChannel{
BillingModelSource: service.BillingModelSourceUpstream,
}
resp := AvailableChannelToAdminResponse(input)
require.Equal(t, service.BillingModelSourceUpstream, resp.BillingModelSource)
}
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AvailableChannelHandler 处理用户侧「可用渠道」查询。
//
// 用户侧接口委托 ChannelService.ListAvailable,并在返回前做三层过滤:
// 1. 行过滤:只保留状态为 Active 且与当前用户可访问分组有交集的渠道;
// 2. 分组过滤:渠道的 Groups 只保留用户可访问的那些;
// 3. 平台过滤:渠道的 SupportedModels 只保留平台在用户可见 Groups 中出现过的模型,
// 防止"渠道同时挂在 antigravity / anthropic 两个平台的分组上,用户只访问
// antigravity,却看到 anthropic 模型"这类跨平台信息泄漏;
// 4. 字段白名单:仅返回用户需要的字段(省略 BillingModelSource / RestrictModels
// / 内部 ID / Status 等管理字段)。
type AvailableChannelHandler struct {
channelService *service.ChannelService
apiKeyService *service.APIKeyService
}
// NewAvailableChannelHandler 创建用户侧可用渠道 handler。
func NewAvailableChannelHandler(
channelService *service.ChannelService,
apiKeyService *service.APIKeyService,
) *AvailableChannelHandler {
return &AvailableChannelHandler{
channelService: channelService,
apiKeyService: apiKeyService,
}
}
// userAvailableGroup 用户可见的分组概要(白名单字段)。
type userAvailableGroup struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
}
// userSupportedModelPricing 用户可见的定价字段白名单。
type userSupportedModelPricing struct {
BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"`
PerRequestPrice *float64 `json:"per_request_price"`
Intervals []userPricingIntervalDTO `json:"intervals"`
}
// userPricingIntervalDTO 定价区间白名单(去掉内部 ID、SortOrder 等前端不渲染的字段)。
type userPricingIntervalDTO struct {
MinTokens int `json:"min_tokens"`
MaxTokens *int `json:"max_tokens"`
TierLabel string `json:"tier_label,omitempty"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
PerRequestPrice *float64 `json:"per_request_price"`
}
// userSupportedModel 用户可见的支持模型条目。
type userSupportedModel struct {
Name string `json:"name"`
Platform string `json:"platform"`
Pricing *userSupportedModelPricing `json:"pricing"`
}
// userAvailableChannel 用户可见的渠道条目(白名单字段)。
type userAvailableChannel struct {
Name string `json:"name"`
Description string `json:"description"`
Groups []userAvailableGroup `json:"groups"`
SupportedModels []userSupportedModel `json:"supported_models"`
}
// List 列出当前用户可见的「可用渠道」。
// GET /api/v1/channels/available
func (h *AvailableChannelHandler) List(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
allowedGroupIDs := make(map[int64]struct{}, len(userGroups))
for i := range userGroups {
allowedGroupIDs[userGroups[i].ID] = struct{}{}
}
channels, err := h.channelService.ListAvailable(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]userAvailableChannel, 0, len(channels))
for _, ch := range channels {
if ch.Status != service.StatusActive {
continue
}
visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs)
if len(visibleGroups) == 0 {
continue
}
allowedPlatforms := collectGroupPlatforms(visibleGroups)
out = append(out, userAvailableChannel{
Name: ch.Name,
Description: ch.Description,
Groups: visibleGroups,
SupportedModels: toUserSupportedModels(ch.SupportedModels, allowedPlatforms),
})
}
response.Success(c, out)
}
// collectGroupPlatforms 聚合 visible groups 覆盖的平台集合,用于过滤 SupportedModels。
func collectGroupPlatforms(groups []userAvailableGroup) map[string]struct{} {
set := make(map[string]struct{}, len(groups))
for _, g := range groups {
if g.Platform == "" {
continue
}
set[g.Platform] = struct{}{}
}
return set
}
// filterUserVisibleGroups 仅保留用户可访问的分组。
func filterUserVisibleGroups(
groups []service.AvailableGroupRef,
allowed map[int64]struct{},
) []userAvailableGroup {
visible := make([]userAvailableGroup, 0, len(groups))
for _, g := range groups {
if _, ok := allowed[g.ID]; !ok {
continue
}
visible = append(visible, userAvailableGroup{
ID: g.ID,
Name: g.Name,
Platform: g.Platform,
})
}
return visible
}
// toUserSupportedModels 将 service 层支持模型转换为用户 DTO(字段白名单)。
// 仅保留平台在 allowedPlatforms 中的条目,防止跨平台模型信息泄漏。
// allowedPlatforms 为 nil 时不做平台过滤(保留全部,供测试或明确无过滤场景使用)。
func toUserSupportedModels(
src []service.SupportedModel,
allowedPlatforms map[string]struct{},
) []userSupportedModel {
out := make([]userSupportedModel, 0, len(src))
for i := range src {
m := src[i]
if allowedPlatforms != nil {
if _, ok := allowedPlatforms[m.Platform]; !ok {
continue
}
}
out = append(out, userSupportedModel{
Name: m.Name,
Platform: m.Platform,
Pricing: toUserPricing(m.Pricing),
})
}
return out
}
// toUserPricing 将 service 层定价转换为用户 DTO;入参为 nil 时返回 nil。
func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing {
if p == nil {
return nil
}
intervals := make([]userPricingIntervalDTO, 0, len(p.Intervals))
for _, iv := range p.Intervals {
intervals = append(intervals, userPricingIntervalDTO{
MinTokens: iv.MinTokens,
MaxTokens: iv.MaxTokens,
TierLabel: iv.TierLabel,
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
})
}
billingMode := string(p.BillingMode)
if billingMode == "" {
billingMode = string(service.BillingModeToken)
}
return &userSupportedModelPricing{
BillingMode: billingMode,
InputPrice: p.InputPrice,
OutputPrice: p.OutputPrice,
CacheWritePrice: p.CacheWritePrice,
CacheReadPrice: p.CacheReadPrice,
ImageOutputPrice: p.ImageOutputPrice,
PerRequestPrice: p.PerRequestPrice,
Intervals: intervals,
}
}
//go:build unit
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestUserAvailableChannel_Unauthenticated401(t *testing.T) {
// 没有 AuthSubject 注入时,handler 应返回 401 且不触达 service 依赖。
gin.SetMode(gin.TestMode)
h := &AvailableChannelHandler{} // nil services — 401 路径不会调用它们
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/channels/available", nil)
h.List(c)
require.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) {
// 渠道挂在 {g1, g2, g3},用户只允许 {g1, g3} —— 响应必须仅含 g1/g3。
groups := []service.AvailableGroupRef{
{ID: 1, Name: "g1", Platform: "anthropic"},
{ID: 2, Name: "g2", Platform: "anthropic"},
{ID: 3, Name: "g3", Platform: "openai"},
}
allowed := map[int64]struct{}{1: {}, 3: {}}
visible := filterUserVisibleGroups(groups, allowed)
require.Len(t, visible, 2)
ids := []int64{visible[0].ID, visible[1].ID}
require.ElementsMatch(t, []int64{1, 3}, ids)
}
func TestCollectGroupPlatforms_DerivesAllowedSet(t *testing.T) {
groups := []userAvailableGroup{
{ID: 1, Platform: "anthropic"},
{ID: 2, Platform: "openai"},
{ID: 3, Platform: "anthropic"}, // 去重
{ID: 4, Platform: ""}, // 空平台忽略
}
got := collectGroupPlatforms(groups)
require.Len(t, got, 2)
_, hasAnt := got["anthropic"]
_, hasOA := got["openai"]
require.True(t, hasAnt)
require.True(t, hasOA)
}
func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) {
// 用户可访问分组只覆盖 anthropic;anthropic 平台的模型保留,openai 模型被剔除。
src := []service.SupportedModel{
{Name: "claude-sonnet-4-6", Platform: "anthropic", Pricing: nil},
{Name: "gpt-4o", Platform: "openai", Pricing: nil},
}
allowed := map[string]struct{}{"anthropic": {}}
out := toUserSupportedModels(src, allowed)
require.Len(t, out, 1)
require.Equal(t, "claude-sonnet-4-6", out[0].Name)
}
func TestToUserSupportedModels_NilAllowedPlatformsKeepsAll(t *testing.T) {
// 显式传 nil allowedPlatforms 表示不做过滤。
src := []service.SupportedModel{
{Name: "a", Platform: "anthropic"},
{Name: "b", Platform: "openai"},
}
require.Len(t, toUserSupportedModels(src, nil), 2)
}
func TestUserAvailableChannel_FieldWhitelist(t *testing.T) {
// 通过序列化 userAvailableChannel 结构体验证响应形状:
// 只有 name / description / groups / supported_models;不含管理端字段。
row := userAvailableChannel{
Name: "ch",
Description: "d",
Groups: []userAvailableGroup{{ID: 1, Name: "g1", Platform: "anthropic"}},
SupportedModels: []userSupportedModel{},
}
raw, err := json.Marshal(row)
require.NoError(t, err)
var decoded map[string]any
require.NoError(t, json.Unmarshal(raw, &decoded))
for _, key := range []string{"id", "status", "billing_model_source", "restrict_models"} {
_, exists := decoded[key]
require.Falsef(t, exists, "user DTO must not expose %q", key)
}
for _, key := range []string{"name", "description", "groups", "supported_models"} {
_, exists := decoded[key]
require.Truef(t, exists, "user DTO must expose %q", key)
}
// pricing interval 白名单:不应暴露 id / sort_order。
pricing := toUserPricing(&service.ChannelModelPricing{
BillingMode: service.BillingModeToken,
Intervals: []service.PricingInterval{
{ID: 7, MinTokens: 0, MaxTokens: nil, SortOrder: 3},
},
})
require.NotNil(t, pricing)
require.Len(t, pricing.Intervals, 1)
rawIv, err := json.Marshal(pricing.Intervals[0])
require.NoError(t, err)
var ivDecoded map[string]any
require.NoError(t, json.Unmarshal(rawIv, &ivDecoded))
for _, key := range []string{"id", "pricing_id", "sort_order"} {
_, exists := ivDecoded[key]
require.Falsef(t, exists, "user pricing interval must not expose %q", key)
}
}
...@@ -33,26 +33,28 @@ type AdminHandlers struct { ...@@ -33,26 +33,28 @@ type AdminHandlers struct {
Channel *admin.ChannelHandler Channel *admin.ChannelHandler
ChannelMonitor *admin.ChannelMonitorHandler ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
AvailableChannel *admin.AvailableChannelHandler
Payment *admin.PaymentHandler Payment *admin.PaymentHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers
type Handlers struct { type Handlers struct {
Auth *AuthHandler Auth *AuthHandler
User *UserHandler User *UserHandler
APIKey *APIKeyHandler APIKey *APIKeyHandler
Usage *UsageHandler Usage *UsageHandler
Redeem *RedeemHandler Redeem *RedeemHandler
Subscription *SubscriptionHandler Subscription *SubscriptionHandler
Announcement *AnnouncementHandler Announcement *AnnouncementHandler
ChannelMonitor *ChannelMonitorUserHandler ChannelMonitor *ChannelMonitorUserHandler
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler Setting *SettingHandler
Totp *TotpHandler Totp *TotpHandler
Payment *PaymentHandler Payment *PaymentHandler
PaymentWebhook *PaymentWebhookHandler PaymentWebhook *PaymentWebhookHandler
AvailableChannel *AvailableChannelHandler
} }
// BuildInfo contains build-time information // BuildInfo contains build-time information
......
...@@ -36,6 +36,7 @@ func ProvideAdminHandlers( ...@@ -36,6 +36,7 @@ func ProvideAdminHandlers(
channelHandler *admin.ChannelHandler, channelHandler *admin.ChannelHandler,
channelMonitorHandler *admin.ChannelMonitorHandler, channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler, channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
availableChannelHandler *admin.AvailableChannelHandler,
paymentHandler *admin.PaymentHandler, paymentHandler *admin.PaymentHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
...@@ -66,6 +67,7 @@ func ProvideAdminHandlers( ...@@ -66,6 +67,7 @@ func ProvideAdminHandlers(
Channel: channelHandler, Channel: channelHandler,
ChannelMonitor: channelMonitorHandler, ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler, ChannelMonitorTemplate: channelMonitorTemplateHandler,
AvailableChannel: availableChannelHandler,
Payment: paymentHandler, Payment: paymentHandler,
} }
} }
...@@ -97,25 +99,27 @@ func ProvideHandlers( ...@@ -97,25 +99,27 @@ func ProvideHandlers(
totpHandler *TotpHandler, totpHandler *TotpHandler,
paymentHandler *PaymentHandler, paymentHandler *PaymentHandler,
paymentWebhookHandler *PaymentWebhookHandler, paymentWebhookHandler *PaymentWebhookHandler,
availableChannelHandler *AvailableChannelHandler,
_ *service.IdempotencyCoordinator, _ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService, _ *service.IdempotencyCleanupService,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
Auth: authHandler, Auth: authHandler,
User: userHandler, User: userHandler,
APIKey: apiKeyHandler, APIKey: apiKeyHandler,
Usage: usageHandler, Usage: usageHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Announcement: announcementHandler, Announcement: announcementHandler,
ChannelMonitor: channelMonitorUserHandler, ChannelMonitor: channelMonitorUserHandler,
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
Totp: totpHandler, Totp: totpHandler,
Payment: paymentHandler, Payment: paymentHandler,
PaymentWebhook: paymentWebhookHandler, PaymentWebhook: paymentWebhookHandler,
AvailableChannel: availableChannelHandler,
} }
} }
...@@ -136,6 +140,7 @@ var ProviderSet = wire.NewSet( ...@@ -136,6 +140,7 @@ var ProviderSet = wire.NewSet(
ProvideSettingHandler, ProvideSettingHandler,
NewPaymentHandler, NewPaymentHandler,
NewPaymentWebhookHandler, NewPaymentWebhookHandler,
NewAvailableChannelHandler,
// Admin handlers // Admin handlers
admin.NewDashboardHandler, admin.NewDashboardHandler,
...@@ -165,6 +170,7 @@ var ProviderSet = wire.NewSet( ...@@ -165,6 +170,7 @@ var ProviderSet = wire.NewSet(
admin.NewChannelHandler, admin.NewChannelHandler,
admin.NewChannelMonitorHandler, admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler, admin.NewChannelMonitorRequestTemplateHandler,
admin.NewAvailableChannelHandler,
admin.NewPaymentHandler, admin.NewPaymentHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
......
...@@ -560,6 +560,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -560,6 +560,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels := admin.Group("/channels") channels := admin.Group("/channels")
{ {
channels.GET("", h.Admin.Channel.List) channels.GET("", h.Admin.Channel.List)
channels.GET("/available", h.Admin.AvailableChannel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing) channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/:id", h.Admin.Channel.GetByID) channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create) channels.POST("", h.Admin.Channel.Create)
......
...@@ -68,6 +68,12 @@ func RegisterUserRoutes( ...@@ -68,6 +68,12 @@ func RegisterUserRoutes(
groups.GET("/rates", h.APIKey.GetUserGroupRates) groups.GET("/rates", h.APIKey.GetUserGroupRates)
} }
// 用户可用渠道(非管理员接口)
channels := authenticated.Group("/channels")
{
channels.GET("/available", h.AvailableChannel.List)
}
// 使用记录 // 使用记录
usage := authenticated.Group("/usage") usage := authenticated.Group("/usage")
{ {
......
...@@ -345,3 +345,175 @@ type ChannelUsageFields struct { ...@@ -345,3 +345,175 @@ type ChannelUsageFields struct {
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped" BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c" ModelMappingChain string // 映射链描述,如 "a→b→c"
} }
// SupportedModel 渠道的一个支持模型条目(无通配符、可直接展示给用户)
type SupportedModel struct {
Name string // 用户侧模型名
Platform string // 所属平台
Pricing *ChannelModelPricing // 定价详情(nil 表示未配置定价)
}
// wildcardSuffix 是模型模式中的通配符后缀标记(仅支持尾部匹配)。
const wildcardSuffix = "*"
// splitWildcardSuffix 将模型模式拆分为 (prefix, isWildcard)。
//
// "claude-opus-*" → ("claude-opus-", true)
// "claude-opus-4" → ("claude-opus-4", false)
// "*" → ("", true)
//
// 注意:返回的 prefix 保持原始大小写,由调用方按需 ToLower。
func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) {
if strings.HasSuffix(pattern, wildcardSuffix) {
return strings.TrimSuffix(pattern, wildcardSuffix), true
}
return pattern, false
}
// GetModelPricingByPlatform 在指定平台下查找精确模型的定价,未找到返回 nil。
// 与 GetModelPricing 的区别:按 Platform 隔离,避免跨平台同名模型误匹配。
func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelModelPricing {
if c == nil {
return nil
}
modelLower := strings.ToLower(model)
for i := range c.ModelPricing {
if c.ModelPricing[i].Platform != platform {
continue
}
for _, m := range c.ModelPricing[i].Models {
if strings.ToLower(m) == modelLower {
cp := c.ModelPricing[i].Clone()
return &cp
}
}
}
return nil
}
// pricingLookup 是渠道定价在单个计算过程中的索引:platform → (lowerName → *pricing)。
// 用于将 SupportedModels 的定价解析从 O(N*M) 降到 O(N+M)。
type pricingLookup map[string]map[string]*ChannelModelPricing
// buildPricingLookup 对渠道的定价列表做一次扫描,生成 platform+模型名 的索引。
// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。
// wildcard 后缀(如 "claude-*")不会被索引(它们不是精确模型名)。
func buildPricingLookup(pricings []ChannelModelPricing) pricingLookup {
lookup := make(pricingLookup, len(pricings))
for i := range pricings {
p := pricings[i]
byModel, ok := lookup[p.Platform]
if !ok {
byModel = make(map[string]*ChannelModelPricing, len(p.Models))
lookup[p.Platform] = byModel
}
for _, m := range p.Models {
if _, wild := splitWildcardSuffix(m); wild {
continue
}
lower := strings.ToLower(m)
if _, exists := byModel[lower]; exists {
continue // 首个命中胜出(保持 case-insensitive 去重后第一个定价)
}
cp := pricings[i].Clone()
byModel[lower] = &cp
}
}
return lookup
}
// pricedNamesFor 返回指定平台下已索引的精确模型名(保留原始大小写,按添加顺序)。
// 它是从 pricingLookup 中取 keys 并回查原始 ModelPricing 以得到原样字符串。
func pricedNamesFor(pricings []ChannelModelPricing, platform string) []string {
seen := make(map[string]struct{})
out := make([]string, 0)
for i := range pricings {
if pricings[i].Platform != platform {
continue
}
for _, m := range pricings[i].Models {
if _, wild := splitWildcardSuffix(m); wild {
continue
}
lower := strings.ToLower(m)
if _, ok := seen[lower]; ok {
continue
}
seen[lower] = struct{}{}
out = append(out, m)
}
}
return out
}
// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。
//
// 算法(以渠道自身的 ModelMapping 为唯一入口):
// - 遍历 Channel.ModelMapping 的每个 platform 条目;
// - 映射 key 不带尾部 "*":直接作为一个支持模型名(即使没有匹配的定价行,也会产出 Pricing=nil 的条目);
// - 映射 key 带尾部 "*":用同 platform 的 ModelPricing.Models 做前缀匹配展开(定价中带 "*" 的条目被忽略,因为它们本身就是模式,不是具体模型名);
// - 未在 ModelMapping 中出现的 platform 不会产出任何条目——这是**刻意设计**("没配映射就不显示"),即使该平台有定价行。
//
// 每个结果尝试从 pricingLookup(平台+模型名索引)查找精确定价,未配置则 Pricing=nil。
// 结果按 (Platform, Name) 稳定排序,并按 (Platform, lowercase(Name)) 去重。
func (c *Channel) SupportedModels() []SupportedModel {
if c == nil || len(c.ModelMapping) == 0 {
return nil
}
lookup := buildPricingLookup(c.ModelPricing)
type dedupKey struct {
platform string
name string
}
seen := make(map[dedupKey]struct{})
result := make([]SupportedModel, 0)
add := func(platform, name string) {
key := dedupKey{platform: platform, name: strings.ToLower(name)}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
var pricing *ChannelModelPricing
if byModel, ok := lookup[platform]; ok {
if p, ok := byModel[strings.ToLower(name)]; ok {
pricing = p
}
}
result = append(result, SupportedModel{
Name: name,
Platform: platform,
Pricing: pricing,
})
}
for platform, mapping := range c.ModelMapping {
if len(mapping) == 0 {
continue
}
pricedNames := pricedNamesFor(c.ModelPricing, platform)
for src := range mapping {
prefix, isWild := splitWildcardSuffix(src)
if isWild {
prefixLower := strings.ToLower(prefix)
for _, candidate := range pricedNames {
if strings.HasPrefix(strings.ToLower(candidate), prefixLower) {
add(platform, candidate)
}
}
continue
}
add(platform, src)
}
}
sort.Slice(result, func(i, j int) bool {
if result[i].Platform != result[j].Platform {
return result[i].Platform < result[j].Platform
}
return result[i].Name < result[j].Name
})
return result
}
package service
import (
"context"
"fmt"
"sort"
"strings"
)
// AvailableGroupRef 渠道视图中关联分组的简要信息。
type AvailableGroupRef struct {
ID int64
Name string
Platform string
}
// AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 +
// 关联的分组 + 推导出的支持模型列表(无通配符)。
type AvailableChannel struct {
ID int64
Name string
Description string
Status string
BillingModelSource string
RestrictModels bool
Groups []AvailableGroupRef
SupportedModels []SupportedModel
}
// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。
//
// 支持模型通过 (*Channel).SupportedModels() 计算得到(见 channel.go)。
// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中
// 的分组(已停用或删除)会被忽略。
func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) {
channels, err := s.repo.ListAll(ctx)
if err != nil {
return nil, fmt.Errorf("list channels: %w", err)
}
groupByID := make(map[int64]AvailableGroupRef)
if s.groupRepo != nil {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
for i := range groups {
g := groups[i]
groupByID[g.ID] = AvailableGroupRef{
ID: g.ID,
Name: g.Name,
Platform: g.Platform,
}
}
}
out := make([]AvailableChannel, 0, len(channels))
for i := range channels {
ch := &channels[i]
groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs))
for _, gid := range ch.GroupIDs {
if ref, ok := groupByID[gid]; ok {
groups = append(groups, ref)
}
}
sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
out = append(out, AvailableChannel{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
BillingModelSource: ch.BillingModelSource,
RestrictModels: ch.RestrictModels,
Groups: groups,
SupportedModels: ch.SupportedModels(),
})
}
sort.SliceStable(out, func(i, j int) bool {
return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name)
})
return out, nil
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub,
// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。
type stubGroupRepoForAvailable struct {
activeGroups []Group
}
func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) {
return s.activeGroups, nil
}
func (s *stubGroupRepoForAvailable) Create(ctx context.Context, group *Group) error { return nil }
func (s *stubGroupRepoForAvailable) GetByID(ctx context.Context, id int64) (*Group, error) {
return nil, nil
}
func (s *stubGroupRepoForAvailable) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
return nil, nil
}
func (s *stubGroupRepoForAvailable) Update(ctx context.Context, group *Group) error { return nil }
func (s *stubGroupRepoForAvailable) Delete(ctx context.Context, id int64) error { return nil }
func (s *stubGroupRepoForAvailable) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (s *stubGroupRepoForAvailable) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubGroupRepoForAvailable) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubGroupRepoForAvailable) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
return nil, nil
}
func (s *stubGroupRepoForAvailable) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (s *stubGroupRepoForAvailable) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (s *stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (s *stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func (s *stubGroupRepoForAvailable) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (s *stubGroupRepoForAvailable) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return nil
}
// newAvailableChannelService 构造一个 ChannelService,channelRepo.ListAll 返回给定 channels,
// groupRepo 由参数决定(可传 nil 测试 nil 分支)。
func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService {
repo := &mockChannelRepository{
listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil },
}
return NewChannelService(repo, groupRepo, nil)
}
func TestListAvailable_NilGroupRepo_NoGroupsAttached(t *testing.T) {
// groupRepo 为 nil 时不应 panic,且每个渠道的 Groups 应为空切片。
channels := []Channel{{
ID: 1,
Name: "chA",
Status: StatusActive,
GroupIDs: []int64{10, 20},
}}
svc := newAvailableChannelService(channels, nil)
out, err := svc.ListAvailable(context.Background())
require.NoError(t, err)
require.Len(t, out, 1)
require.Empty(t, out[0].Groups)
}
func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) {
// 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。
channels := []Channel{{
ID: 1,
Name: "chA",
Status: StatusActive,
GroupIDs: []int64{1, 99},
}}
groupRepo := &stubGroupRepoForAvailable{
activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}},
}
svc := newAvailableChannelService(channels, groupRepo)
out, err := svc.ListAvailable(context.Background())
require.NoError(t, err)
require.Len(t, out, 1)
require.Len(t, out[0].Groups, 1)
require.Equal(t, int64(1), out[0].Groups[0].ID)
}
func TestListAvailable_SortedByName(t *testing.T) {
channels := []Channel{
{ID: 1, Name: "beta"},
{ID: 2, Name: "Alpha"},
{ID: 3, Name: "charlie"},
}
svc := newAvailableChannelService(channels, nil)
out, err := svc.ListAvailable(context.Background())
require.NoError(t, err)
require.Len(t, out, 3)
require.Equal(t, "Alpha", out[0].Name)
require.Equal(t, "beta", out[1].Name)
require.Equal(t, "charlie", out[2].Name)
}
...@@ -141,6 +141,7 @@ const ( ...@@ -141,6 +141,7 @@ const (
// ChannelService 渠道管理服务 // ChannelService 渠道管理服务
type ChannelService struct { type ChannelService struct {
repo ChannelRepository repo ChannelRepository
groupRepo GroupRepository
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
cache atomic.Value // *channelCache cache atomic.Value // *channelCache
...@@ -148,9 +149,10 @@ type ChannelService struct { ...@@ -148,9 +149,10 @@ type ChannelService struct {
} }
// NewChannelService 创建渠道服务实例 // NewChannelService 创建渠道服务实例
func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService { func NewChannelService(repo ChannelRepository, groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
s := &ChannelService{ s := &ChannelService{
repo: repo, repo: repo,
groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator, authCacheInvalidator: authCacheInvalidator,
} }
return s return s
...@@ -884,12 +886,7 @@ func conflictsBetween(a, b modelEntry) bool { ...@@ -884,12 +886,7 @@ func conflictsBetween(a, b modelEntry) bool {
// toModelEntry 将模型名转换为 modelEntry // toModelEntry 将模型名转换为 modelEntry
func toModelEntry(pattern string) modelEntry { func toModelEntry(pattern string) modelEntry {
lower := strings.ToLower(pattern) prefix, isWild := splitWildcardSuffix(strings.ToLower(pattern))
isWild := strings.HasSuffix(lower, "*")
prefix := lower
if isWild {
prefix = strings.TrimSuffix(lower, "*")
}
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild} return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
} }
......
...@@ -189,11 +189,11 @@ func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context ...@@ -189,11 +189,11 @@ func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func newTestChannelService(repo *mockChannelRepository) *ChannelService { func newTestChannelService(repo *mockChannelRepository) *ChannelService {
return NewChannelService(repo, nil) return NewChannelService(repo, nil, nil)
} }
func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService { func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService {
return NewChannelService(repo, auth) return NewChannelService(repo, nil, auth)
} }
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing // makeStandardRepo returns a repo that serves one active channel with anthropic pricing
......
...@@ -433,3 +433,207 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) { ...@@ -433,3 +433,207 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
require.Contains(t, err.Error(), "unbounded") require.Contains(t, err.Error(), "unbounded")
require.Contains(t, err.Error(), "last") require.Contains(t, err.Error(), "last")
} }
func TestSupportedModels_ExactKeysAndPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 10, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
{ID: 11, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)},
},
ModelMapping: map[string]map[string]string{
"anthropic": {
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-opus-4-6": "claude-opus-4-6",
},
},
}
got := ch.SupportedModels()
require.Len(t, got, 2)
require.Equal(t, "anthropic", got[0].Platform)
require.Equal(t, "claude-opus-4-6", got[0].Name)
require.NotNil(t, got[0].Pricing)
require.Equal(t, int64(11), got[0].Pricing.ID)
require.Equal(t, "claude-sonnet-4-6", got[1].Name)
require.Equal(t, int64(10), got[1].Pricing.ID)
}
func TestSupportedModels_WildcardExpandedFromPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
{ID: 2, Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {
"claude-sonnet-*": "claude-sonnet-4-6",
},
},
}
got := ch.SupportedModels()
names := make([]string, 0, len(got))
for _, m := range got {
names = append(names, m.Name)
}
require.ElementsMatch(t, []string{"claude-sonnet-4-5", "claude-sonnet-4-6"}, names)
for _, m := range got {
require.NotContains(t, m.Name, "*")
}
}
func TestSupportedModels_PlatformWithoutMappingSkipped(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
{ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"},
// openai 没有 mapping 条目
},
}
got := ch.SupportedModels()
require.Len(t, got, 1)
require.Equal(t, "anthropic", got[0].Platform)
require.Equal(t, "claude-sonnet-4-6", got[0].Name)
}
func TestSupportedModels_MissingPricingKeepsNilPricing(t *testing.T) {
ch := &Channel{
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"},
},
}
got := ch.SupportedModels()
require.Len(t, got, 1)
require.Equal(t, "claude-sonnet-4-6", got[0].Name)
require.Nil(t, got[0].Pricing)
}
func TestSupportedModels_DedupAndSort(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
{ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {
"claude-sonnet-4-6": "upstream-a",
"claude-sonnet-*": "upstream-a",
},
"openai": {"gpt-4o": "gpt-4o"},
},
}
got := ch.SupportedModels()
require.Len(t, got, 3)
require.Equal(t, "anthropic", got[0].Platform)
require.Equal(t, "claude-sonnet-4-5", got[0].Name)
require.Equal(t, "anthropic", got[1].Platform)
require.Equal(t, "claude-sonnet-4-6", got[1].Name)
require.Equal(t, "openai", got[2].Platform)
require.Equal(t, "gpt-4o", got[2].Name)
}
func TestSupportedModels_NilChannelAndEmpty(t *testing.T) {
var nilCh *Channel
require.Nil(t, nilCh.SupportedModels())
empty := &Channel{}
require.Nil(t, empty.SupportedModels())
}
func TestGetModelPricingByPlatform(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
{ID: 2, Platform: "openai", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(1e-6)},
},
}
ant := ch.GetModelPricingByPlatform("anthropic", "claude-sonnet-4-6")
require.NotNil(t, ant)
require.Equal(t, int64(1), ant.ID)
oa := ch.GetModelPricingByPlatform("openai", "claude-sonnet-4-6")
require.NotNil(t, oa)
require.Equal(t, int64(2), oa.ID)
require.Nil(t, ch.GetModelPricingByPlatform("gemini", "claude-sonnet-4-6"))
}
func TestSupportedModels_WildcardOnlyPricingRowsSkipped(t *testing.T) {
// 定价中含通配符条目(pattern),不应被当作具体模型名展开。
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-*", "claude-sonnet-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-*": "claude-sonnet-4-6"},
},
}
got := ch.SupportedModels()
require.Len(t, got, 1)
require.Equal(t, "claude-sonnet-4-6", got[0].Name)
for _, m := range got {
require.NotContains(t, m.Name, "*")
}
}
func TestSupportedModels_WildcardPrefixMatchesNothing(t *testing.T) {
// 通配符模式无任何对应定价模型时,该平台应产出 0 个模型。
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "openai", Models: []string{"gpt-4o"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"gpt-foo-*": "gpt-foo-1"},
},
}
require.Empty(t, ch.SupportedModels())
}
func TestSupportedModels_CrossPlatformPricingDoesNotBleed(t *testing.T) {
// anthropic 的通配符不应拉入 openai 定价行,哪怕名字恰好前缀匹配。
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "openai", Models: []string{"claude-sonnet-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-*": "x"},
},
}
require.Empty(t, ch.SupportedModels())
}
func TestSupportedModels_CaseInsensitiveDedup(t *testing.T) {
// 两行定价用不同大小写定义了同一模型,结果应去重为 1 条;首次出现的原始大小写保留。
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "openai", Models: []string{"GPT-4o"}},
{ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
},
ModelMapping: map[string]map[string]string{
"openai": {"gpt-*": "x"},
},
}
got := ch.SupportedModels()
require.Len(t, got, 1)
require.Equal(t, "GPT-4o", got[0].Name)
}
func TestSupportedModels_EmptyPlatformMapping(t *testing.T) {
// ModelMapping 有一个 platform key 但 value 是空 map —— 该 platform 应被跳过。
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {},
},
}
require.Empty(t, ch.SupportedModels())
}
...@@ -184,7 +184,7 @@ func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelP ...@@ -184,7 +184,7 @@ func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelP
return map[int64]string{groupID: "anthropic"}, nil return map[int64]string{groupID: "anthropic"}, nil
}, },
} }
cs := NewChannelService(repo, nil) cs := NewChannelService(repo, nil, nil)
bs := newTestBillingServiceForResolver() bs := newTestBillingServiceForResolver()
return NewModelPricingResolver(cs, bs) return NewModelPricingResolver(cs, bs)
} }
...@@ -517,7 +517,7 @@ func TestResolve_WithChannelOverride_CacheError(t *testing.T) { ...@@ -517,7 +517,7 @@ func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
return nil, errors.New("database unavailable") return nil, errors.New("database unavailable")
}, },
} }
cs := NewChannelService(repo, nil) cs := NewChannelService(repo, nil, nil)
bs := newTestBillingServiceForResolver() bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(cs, bs) r := NewModelPricingResolver(cs, bs)
......
...@@ -163,5 +163,42 @@ export async function getModelDefaultPricing(model: string): Promise<ModelDefaul ...@@ -163,5 +163,42 @@ export async function getModelDefaultPricing(model: string): Promise<ModelDefaul
return data return data
} }
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing } // --- Available channels (聚合视图:渠道 + 分组 + 支持模型) ---
export interface AvailableGroupRef {
id: number
name: string
platform: string
}
export interface SupportedModel {
name: string
platform: string
pricing: ChannelModelPricing | null
}
export interface AvailableChannel {
id: number
name: string
description: string
status: string
billing_model_source: string
restrict_models: boolean
groups: AvailableGroupRef[]
supported_models: SupportedModel[]
}
interface AvailableChannelsResponse {
items: AvailableChannel[]
}
/** 列出所有可用渠道(含关联分组与支持模型) */
export async function listAvailable(options?: { signal?: AbortSignal }): Promise<AvailableChannel[]> {
const { data } = await apiClient.get<AvailableChannelsResponse>('/admin/channels/available', {
signal: options?.signal
})
return data.items
}
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, listAvailable }
export default channelsAPI export default channelsAPI
/**
* User Channels API endpoints (non-admin)
* 用户侧「可用渠道」聚合查询:渠道 + 用户可访问的分组 + 支持模型(含定价)。
*/
import { apiClient } from './client'
import type { BillingMode } from '@/constants/channel'
export interface UserAvailableGroup {
id: number
name: string
platform: string
}
export interface UserPricingInterval {
min_tokens: number
max_tokens: number | null
tier_label?: string
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
per_request_price: number | null
}
export interface UserSupportedModelPricing {
billing_mode: BillingMode
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
image_output_price: number | null
per_request_price: number | null
intervals: UserPricingInterval[]
}
export interface UserSupportedModel {
name: string
platform: string
pricing: UserSupportedModelPricing | null
}
export interface UserAvailableChannel {
name: string
description: string
groups: UserAvailableGroup[]
supported_models: UserSupportedModel[]
}
/** 列出当前用户可见的「可用渠道」(与 /groups/available 保持一致,返回平数组)。 */
export async function getAvailable(options?: { signal?: AbortSignal }): Promise<UserAvailableChannel[]> {
const { data } = await apiClient.get<UserAvailableChannel[]>('/channels/available', {
signal: options?.signal
})
return data
}
export const userChannelsAPI = { getAvailable }
export default userChannelsAPI
...@@ -16,6 +16,7 @@ export { userAPI } from './user' ...@@ -16,6 +16,7 @@ export { userAPI } from './user'
export { redeemAPI, type RedeemHistoryItem } from './redeem' export { redeemAPI, type RedeemHistoryItem } from './redeem'
export { paymentAPI } from './payment' export { paymentAPI } from './payment'
export { userGroupsAPI } from './groups' export { userGroupsAPI } from './groups'
export { userChannelsAPI } from './channels'
export { totpAPI } from './totp' export { totpAPI } from './totp'
export { default as announcementsAPI } from './announcements' export { default as announcementsAPI } from './announcements'
export { channelMonitorUserAPI } from './channelMonitor' export { channelMonitorUserAPI } from './channelMonitor'
......
<template>
<DataTable :columns="columns" :data="rows" :loading="loading">
<template #cell-name="{ row }">
<div class="font-medium text-gray-900 dark:text-white">{{ row.name }}</div>
<div
v-if="row.description"
class="mt-0.5 text-xs text-gray-500 dark:text-gray-400"
>
{{ row.description }}
</div>
</template>
<template #cell-groups="{ row }">
<div v-if="row.groups.length === 0" class="text-xs text-gray-400">
<slot name="empty-groups">-</slot>
</div>
<div v-else class="flex flex-wrap gap-1">
<span
v-for="g in row.groups"
:key="g.id"
class="inline-flex items-center rounded bg-blue-50 px-2 py-0.5 text-xs font-medium text-blue-700 dark:bg-blue-900/30 dark:text-blue-300"
>
{{ g.name }}
</span>
</div>
</template>
<template #cell-supported_models="{ row }">
<div v-if="row.supported_models.length === 0" class="text-xs text-gray-400">
{{ noModelsLabel }}
</div>
<div v-else class="flex max-w-[560px] flex-wrap gap-1">
<SupportedModelChip
v-for="m in row.supported_models"
:key="`${m.platform}-${m.name}`"
:model="m"
:pricing-key-prefix="pricingKeyPrefix"
:no-pricing-label="noPricingLabel"
/>
</div>
</template>
<!-- 允许父组件为额外列提供自定义渲染(如 admin 的 status / billing_model_source)。 -->
<template v-for="slot in extraCellSlots" :key="slot" #[slot]="scope">
<slot :name="slot" v-bind="scope" />
</template>
<template #empty>
<slot name="empty">
<div class="flex flex-col items-center py-8">
<Icon name="inbox" size="xl" class="mb-3 h-12 w-12 text-gray-400" />
<p class="text-sm text-gray-500 dark:text-gray-400">{{ emptyLabel }}</p>
</div>
</slot>
</template>
</DataTable>
</template>
<script setup lang="ts">
import { computed, useSlots } from 'vue'
import DataTable from '@/components/common/DataTable.vue'
import Icon from '@/components/icons/Icon.vue'
import SupportedModelChip from './SupportedModelChip.vue'
interface GroupRef {
id: number
name: string
platform?: string
}
interface Row {
name: string
description?: string
groups: GroupRef[]
supported_models: Array<{
name: string
platform: string
pricing: unknown | null
}>
[key: string]: unknown
}
interface Column {
key: string
label: string
}
withDefaults(
defineProps<{
columns: Column[]
rows: Row[]
loading: boolean
pricingKeyPrefix: string
noPricingLabel: string
noModelsLabel: string
emptyLabel: string
}>(),
{ loading: false }
)
const slots = useSlots()
/**
* 透传父组件提供的 cell-* 插槽(除本组件内置的 name/groups/supported_models/empty-groups/empty
* 之外),让 admin 场景可以自定义 status / billing_model_source 等列。
*/
const extraCellSlots = computed(() => {
const reserved = new Set(['cell-name', 'cell-groups', 'cell-supported_models', 'empty-groups', 'empty'])
return Object.keys(slots).filter((name) => name.startsWith('cell-') && !reserved.has(name))
})
</script>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment