Commit 91bdcf89 authored by erio's avatar erio
Browse files

fix(channel): 模型限制用映射后模型检查 + 平台开关保留配置不删除

- OpenAI 网关三处 IsModelRestricted 改用 channelMapping.MappedModel
- 前端平台勾选改为 enabled 开关,取消勾选不清空配置数据
- formToAPI/校验只处理 enabled 的平台
parent 8d03c52e
...@@ -164,9 +164,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -164,9 +164,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
} }
// 渠道模型限制检查:使用原始请求模型名,因为定价列表中注册的是用户请求的模型名 // 渠道模型限制检查:先映射再判断,映射后的模型在定价列表中即放行
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) { if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, channelMapping.MappedModel) {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return return
} }
......
...@@ -191,10 +191,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -191,10 +191,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
} }
// 渠道模型限制检查 // 渠道模型限制检查:先映射再判断
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) { if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, channelMapping.MappedModel) {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return return
} }
} }
...@@ -584,10 +583,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -584,10 +583,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
channelMappingMsg = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) channelMappingMsg = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
} }
// 渠道模型限制检查 // 渠道模型限制检查:先映射再判断
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) { if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, channelMappingMsg.MappedModel) {
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return return
} }
} }
...@@ -1165,9 +1163,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1165,9 +1163,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
channelMappingWS = h.gatewayService.ResolveChannelMapping(ctx, *apiKey.GroupID, reqModel) channelMappingWS = h.gatewayService.ResolveChannelMapping(ctx, *apiKey.GroupID, reqModel)
} }
// 渠道模型限制检查 // 渠道模型限制检查:先映射再判断
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
if h.gatewayService.IsModelRestricted(ctx, *apiKey.GroupID, reqModel) { if h.gatewayService.IsModelRestricted(ctx, *apiKey.GroupID, channelMappingWS.MappedModel) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed") closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed")
return return
} }
......
...@@ -155,9 +155,9 @@ ...@@ -155,9 +155,9 @@
> >
{{ t('admin.channels.form.basicSettings', '基础设置') }} {{ t('admin.channels.form.basicSettings', '基础设置') }}
</button> </button>
<!-- Platform Tabs --> <!-- Platform Tabs (only enabled) -->
<button <button
v-for="(section, sIdx) in form.platforms" v-for="section in form.platforms.filter(s => s.enabled)"
:key="section.platform" :key="section.platform"
type="button" type="button"
@click="activeTab = section.platform" @click="activeTab = section.platform"
...@@ -166,12 +166,6 @@ ...@@ -166,12 +166,6 @@
> >
<PlatformIcon :platform="section.platform" size="xs" :class="getPlatformTextColor(section.platform)" /> <PlatformIcon :platform="section.platform" size="xs" :class="getPlatformTextColor(section.platform)" />
<span :class="getPlatformTextColor(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span> <span :class="getPlatformTextColor(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
<span
@click.stop="removePlatformSection(sIdx)"
class="ml-1 rounded-full p-0.5 opacity-0 group-hover:opacity-100 hover:bg-gray-200 dark:hover:bg-dark-600 transition-opacity"
>
<Icon name="x" size="xs" class="text-gray-400 hover:text-red-500" />
</span>
</button> </button>
</div> </div>
...@@ -261,7 +255,7 @@ ...@@ -261,7 +255,7 @@
<div <div
v-for="(section, sIdx) in form.platforms" v-for="(section, sIdx) in form.platforms"
:key="'tab-' + section.platform" :key="'tab-' + section.platform"
v-show="activeTab === section.platform" v-show="section.enabled && activeTab === section.platform"
class="space-y-4" class="space-y-4"
> >
<!-- Groups --> <!-- Groups -->
...@@ -449,6 +443,7 @@ const appStore = useAppStore() ...@@ -449,6 +443,7 @@ const appStore = useAppStore()
// ── Platform Section type ── // ── Platform Section type ──
interface PlatformSection { interface PlatformSection {
platform: GroupPlatform platform: GroupPlatform
enabled: boolean
collapsed: boolean collapsed: boolean
group_ids: number[] group_ids: number[]
model_mapping: Record<string, string> model_mapping: Record<string, string>
...@@ -549,11 +544,12 @@ function formatDate(value: string): string { ...@@ -549,11 +544,12 @@ function formatDate(value: string): string {
} }
// ── Platform section helpers ── // ── Platform section helpers ──
const activePlatforms = computed(() => form.platforms.map(s => s.platform)) const activePlatforms = computed(() => form.platforms.filter(s => s.enabled).map(s => s.platform))
function addPlatformSection(platform: GroupPlatform) { function addPlatformSection(platform: GroupPlatform) {
form.platforms.push({ form.platforms.push({
platform, platform,
enabled: true,
collapsed: false, collapsed: false,
group_ids: [], group_ids: [],
model_mapping: {}, model_mapping: {},
...@@ -562,22 +558,17 @@ function addPlatformSection(platform: GroupPlatform) { ...@@ -562,22 +558,17 @@ function addPlatformSection(platform: GroupPlatform) {
} }
function togglePlatform(platform: GroupPlatform) { function togglePlatform(platform: GroupPlatform) {
const idx = form.platforms.findIndex(s => s.platform === platform) const section = form.platforms.find(s => s.platform === platform)
if (idx >= 0) { if (section) {
removePlatformSection(idx) section.enabled = !section.enabled
if (!section.enabled && activeTab.value === platform) {
activeTab.value = 'basic'
}
} else { } else {
addPlatformSection(platform) addPlatformSection(platform)
} }
} }
function removePlatformSection(idx: number) {
const removed = form.platforms[idx]
form.platforms.splice(idx, 1)
if (activeTab.value === removed.platform) {
activeTab.value = 'basic'
}
}
function getGroupsForPlatform(platform: GroupPlatform): AdminGroup[] { function getGroupsForPlatform(platform: GroupPlatform): AdminGroup[] {
return allGroups.value.filter(g => g.platform === platform) return allGroups.value.filter(g => g.platform === platform)
} }
...@@ -682,6 +673,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ ...@@ -682,6 +673,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
const model_mapping: Record<string, Record<string, string>> = {} const model_mapping: Record<string, Record<string, string>> = {}
for (const section of form.platforms) { for (const section of form.platforms) {
if (!section.enabled) continue
group_ids.push(...section.group_ids) group_ids.push(...section.group_ids)
// Model mapping per platform // Model mapping per platform
...@@ -755,6 +747,7 @@ function apiToForm(channel: Channel): PlatformSection[] { ...@@ -755,6 +747,7 @@ function apiToForm(channel: Channel): PlatformSection[] {
sections.push({ sections.push({
platform, platform,
enabled: true,
collapsed: false, collapsed: false,
group_ids: groupIds, group_ids: groupIds,
model_mapping: { ...mapping }, model_mapping: { ...mapping },
...@@ -868,16 +861,16 @@ async function handleSubmit() { ...@@ -868,16 +861,16 @@ async function handleSubmit() {
return return
} }
// Check duplicate models across all platform sections // Check duplicate models across all enabled platform sections
const allModels = form.platforms.flatMap(s => s.model_pricing.flatMap(e => e.models.map(m => m.toLowerCase()))) const allModels = form.platforms.filter(s => s.enabled).flatMap(s => s.model_pricing.flatMap(e => e.models.map(m => m.toLowerCase())))
const duplicates = allModels.filter((m, i) => allModels.indexOf(m) !== i) const duplicates = allModels.filter((m, i) => allModels.indexOf(m) !== i)
if (duplicates.length > 0) { if (duplicates.length > 0) {
appStore.showError(t('admin.channels.duplicateModels', `模型 "${duplicates[0]}" 在多个定价条目中重复`)) appStore.showError(t('admin.channels.duplicateModels', `模型 "${duplicates[0]}" 在多个定价条目中重复`))
return return
} }
// 校验 per_request/image 模式必须有价格 // 校验 per_request/image 模式必须有价格 (只校验启用的平台)
for (const section of form.platforms) { for (const section of form.platforms.filter(s => s.enabled)) {
for (const entry of section.model_pricing) { for (const entry of section.model_pricing) {
if (entry.models.length === 0) continue if (entry.models.length === 0) continue
if ((entry.billing_mode === 'per_request' || entry.billing_mode === 'image') && if ((entry.billing_mode === 'per_request' || entry.billing_mode === 'image') &&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment