Commit 30f55a1f authored by DaydreamCoding's avatar DaydreamCoding
Browse files

feat(openai): OpenAI Fast/Flex Policy 完整实现(HTTP + WebSocket + Admin)



对称参照 Claude BetaPolicy 的 fast-mode 过滤实现,新增针对 OpenAI 上游
service_tier 字段(priority / flex,含客户端 "fast" → "priority" 归一化)的
pass / filter / block 三态策略,覆盖全部 OpenAI 入口 + admin 配置入口。

后端核心
- 新增 SettingKeyOpenAIFastPolicySettings、OpenAIFastPolicyRule、
  OpenAIFastPolicySettings 配置模型,含规则的 service_tier × action × scope
  × 模型白名单 × fallback action 维度。
- SettingService.Get/SetOpenAIFastPolicySettings;缺失时返回内置默认策略
  (所有模型的 priority 走 filter,whitelist 为空,fallback=pass)。设计
  依据:service_tier=fast 是用户级开关,与 model 字段正交,默认锁定特定
  model slug 会留下"用 gpt-4 + fast 透传 priority 上游"的绕过路径。JSON
  解析失败不再静默 fallback,slog.Warn 记录脏数据,便于运维定位。
- service_tier 归一化(trim + ToLower + fast→priority + 白名单 priority/flex)
  与策略评估(evaluateOpenAIFastPolicy)作为唯一真实来源,HTTP / WS 共用。
  抽出纯函数 evaluateOpenAIFastPolicyWithSettings,配合 ctx-bound settings
  快照(withOpenAIFastPolicyContext / openAIFastPolicySettingsFromContext),
  WS 长会话入口预取一次后所有帧复用,避免每帧打到 settingService。

HTTP 入口(4 个)
- Chat Completions、Anthropic 兼容(Messages,含 BetaFastMode→priority 二次
  命中)、原生 Responses、Passthrough Responses 全部接入
  applyOpenAIFastPolicyToBody,filter 走 sjson 顶层删除 service_tier,block
  返回 403 forbidden_error JSON。
- 4 入口统一使用 upstream 视角的 model(GetMappedModel +
  normalizeOpenAIModelForUpstream + Codex OAuth normalize 后的 slug),
  避免 chat/messages/native /responses/passthrough 因为 model 维度不同
  造成 whitelist 命中差异。
- 在 pass 路径也把客户端 "fast" 别名归一化为 "priority" 写回 body,
  否则 native /responses 与 passthrough 入口会把 "fast" 原样透传给上游
  导致 400/拒绝(chat-completions 入口的 normalizeResponsesBodyServiceTier
  此前已具备同等行为)。

WebSocket 入口
- 新增 applyOpenAIFastPolicyToWSResponseCreate:严格匹配
  type="response.create",仅处理顶层 service_tier;filter 用 sjson 删字段,
  block 返回 typed *OpenAIFastBlockedError。
- ingress 路径在 parseClientPayload 内调用,block 命中先 Write Realtime
  风格 error event 再返回 OpenAIWSClientCloseError(StatusPolicyViolation
  =1008),依赖底层 WebSocket Conn.Write 的同步 flush 保证 error 先于
  close。
- passthrough 路径在 RunEntry 前对 firstClientMessage 应用策略,并通过
  openAIWSPolicyEnforcingFrameConn 包装 ReadFrame 对每个 client→upstream
  帧执行策略;后续帧无 model 字段时回退到 capturedSessionModel。
  filter 闭包内同时侦测 session.update / session.created 帧的 session.model
  字段刷新 capturedSessionModel,封堵"首帧 model=gpt-4o(pass)→
  session.update 改为 gpt-5.5 → 不带 model 的 response.create fallback
  到 gpt-4o"的 mid-session 绕过路径。
- passthrough billing:requestServiceTier 在策略 filter 之后再从
  firstClientMessage 提取,filter 命中时 OpenAIForwardResult.ServiceTier
  上报 nil(default tier),与 HTTP 入口(reqBody 来自 post-filter map)
  / WS ingress(payload 来自 post-filter bytes)的语义一致。
- 错误事件 schema:{event_id: "evt_<32hex>", type: "error",
  error: {type: "forbidden_error", code: "policy_violation", message}},
  与 OpenAI codex 客户端 error event 解析兼容。

Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest 新增
  openai_fast_policy_settings 字段(omitempty),bulk GET/PUT 接入。
- Settings 页 Gateway 页签新增 Fast/Flex Policy 表单卡片:
  service_tier × action × scope × 模型白名单 × fallback action 全字段配置。
- 前端守门:openaiFastPolicyLoaded 标志仅在 GET 真带回字段时才允许回写,
  避免 rollout/错误把默认规则覆盖成空;saveSettings 回写循环 skip 该字段,
  由专用刷新逻辑处理;仅 action=block 时发送 error_message,匹配后端
  omitempty 行为。

测试
- HTTP 路径:openai_fast_policy_test.go 覆盖默认配置(whitelist=[],所有
  模型 priority filter)/ block 自定义错误 / scope 区分 / filter 删字段 /
  block 不改 body / block 短路上游 / Anthropic BetaFastMode 触发 OpenAI
  fast policy 等场景。
- WebSocket 路径:openai_fast_policy_ws_test.go 覆盖
    helper 单元(filter / fast→priority 归一化 / flex 透传 / block typed
    error / 无 service_tier 字节不变 / 非 response.create 帧不动 / 空 type
    帧不动 / event_id+code 字段断言 / 非字符串 service_tier 容错)+
    pass 路径 fast 别名归一化回归 +
    ingress 端到端(filter 后上游不含 service_tier / block 后客户端先收
    error event 再收 close 1008 且上游 0 写)+
    passthrough capturedSessionModel fallback 用例(whitelist 策略下首帧
    建立、缺 model 命中 fallback、缺少 fallback 时的 leak 文档化)+
    passthrough session.update / session.created 旋转 capturedSessionModel
    的 mid-session 绕过回归 +
    passthrough billing post-filter ServiceTier 与 idempotent filter 回归。
Co-Authored-By: default avatarClaude Opus 4.7 (1M context) <noreply@anthropic.com>
parent c92b88e3
......@@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
......
......@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
......@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}
// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
t.Run("nil input returns nil", func(t *testing.T) {
require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
})
t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: "",
Action: "filter",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.NotNil(t, out)
require.Len(t, out.Rules, 1)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
require.Equal(t, "all", out.Rules[0].ServiceTier)
})
t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: " ",
Action: "pass",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
})
t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: "PRIORITY",
Action: "filter",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
})
t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{
{ServiceTier: "priority", Action: "filter", Scope: "all"},
{ServiceTier: "flex", Action: "block", Scope: "oauth"},
{ServiceTier: "all", Action: "pass", Scope: "apikey"},
},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Len(t, out.Rules, 3)
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
})
}
......@@ -248,9 +248,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
AffiliateEnabled: settings.AffiliateEnabled,
}
// OpenAI fast policy (stored under a dedicated setting key)
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
if s == nil {
return nil
}
rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
for i, r := range s.Rules {
rules[i] = dto.OpenAIFastPolicyRule(r)
}
return &dto.OpenAIFastPolicySettings{Rules: rules}
}
// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
//
// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为
// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时
// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由
// service.SetOpenAIFastPolicySettings 负责合法值校验。
func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
if s == nil {
return nil
}
rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
for i, r := range s.Rules {
rules[i] = service.OpenAIFastPolicyRule(r)
tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
if tier == "" {
tier = service.OpenAIFastTierAny
}
rules[i].ServiceTier = tier
}
return &service.OpenAIFastPolicySettings{Rules: rules}
}
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
......@@ -452,6 +494,9 @@ type UpdateSettingsRequest struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
// UpdateSettings 更新系统设置
......@@ -1350,6 +1395,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
// Update OpenAI fast policy (stored under dedicated key, only when provided).
if req.OpenAIFastPolicySettings != nil {
if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
response.BadRequest(c, err.Error())
return
}
}
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) {
......@@ -1555,6 +1608,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
......
......@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
}
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
if s.values != nil {
if value, ok := s.values[key]; ok {
return value, nil
}
}
return "", nil
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
......
......@@ -198,6 +198,9 @@ type SystemSettings struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
type DefaultSubscriptionSetting struct {
......@@ -294,6 +297,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO
type OpenAIFastPolicyRule struct {
ServiceTier string `json:"service_tier"`
Action string `json:"action"`
Scope string `json:"scope"`
ErrorMessage string `json:"error_message,omitempty"`
ModelWhitelist []string `json:"model_whitelist,omitempty"`
FallbackAction string `json:"fallback_action,omitempty"`
FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
}
// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO
type OpenAIFastPolicySettings struct {
Rules []OpenAIFastPolicyRule `json:"rules"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {
......
......@@ -748,6 +748,16 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
......@@ -930,6 +940,16 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
......
......@@ -306,6 +306,12 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
// targets OpenAI's body-level service_tier field instead of Claude's
// anthropic-beta header.
SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
// =========================
// Claude Code Version Check
// =========================
......
package service
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type openAIFastPolicyRepoStub struct {
values map[string]string
}
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
return nil
}
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
t.Helper()
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
if settings != nil {
raw, err := json.Marshal(settings)
require.NoError(t, err)
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
}
return &OpenAIGatewayService{
settingService: NewSettingService(repo, &config.Config{}),
}
}
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast
// 是用户级开关,与 model 正交。
// gpt-5.5 + priority → filter
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// gpt-5.5-turbo → filter
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// gpt-4 + priority → filter(默认策略覆盖所有模型)
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// gpt-5.5 + flex → pass (tier doesn't match)
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
require.Equal(t, BetaPolicyActionPass, action)
// empty tier → pass
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
require.Equal(t, BetaPolicyActionPass, action)
}
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is not allowed",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionBlock, action)
require.Equal(t, "fast mode is not allowed", msg)
}
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeOAuth,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
// OAuth account → rule matches
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// API Key account → rule skipped → pass
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionPass, action)
}
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// gpt-5.5 fast → service_tier stripped
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`)
// Client sending "fast" (alias for priority) also filtered
body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`)
// gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`)
// No service_tier → no-op
body = []byte(`{"model":"gpt-5.5"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.Equal(t, string(body), string(updated))
}
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后
// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被
// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
for _, tier := range []string{"auto", "default", "scale"} {
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err, "tier %q should pass without error", tier)
require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
"tier %q should be preserved in body under default rule", tier)
}
// evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配)
for _, tier := range []string{"auto", "default", "scale"} {
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
}
}
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会
// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`,
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
}
}
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离
// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段;
// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在
// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// normalize 阶段会将未知值剥离
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
// applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变
// (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离)
body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.Equal(t, string(body), string(updated))
}
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is blocked for gpt-5.5",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.Error(t, err)
var blocked *OpenAIFastBlockedError
require.True(t, errors.As(err, &blocked))
require.Contains(t, blocked.Message, "fast mode is blocked")
require.Equal(t, string(body), string(updated)) // body not mutated on block
}
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, &config.Config{})
// Invalid action rejected
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: "bogus",
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Invalid service_tier rejected
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: "turbo",
Action: BetaPolicyActionPass,
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Valid settings persisted
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
})
require.NoError(t, err)
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
require.NoError(t, err)
require.Len(t, got.Rules, 1)
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}
package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate ---
func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier")
// Other fields preserved.
require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String())
require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String())
require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String())
}
func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Verbatim "fast" → normalized to "priority" → matches default rule → filter.
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.NotContains(t, string(updated), `"service_tier"`)
// Mixed-case + whitespace variant should also normalize and filter.
frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`)
updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.NotContains(t, string(updated), `"service_tier"`)
}
func TestWSResponseCreate_FlexPassThrough(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Default policy targets priority only; flex is left untouched.
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, "flex", gjson.GetBytes(updated, "service_tier").String(), "flex frames must reach upstream untouched under default policy")
}
func TestWSResponseCreate_BlockReturnsTypedError(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "ws fast blocked",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.NotNil(t, blocked)
require.Equal(t, "ws fast blocked", blocked.Message)
// On block, payload returned unchanged so caller can inspect / log it.
require.Equal(t, string(frame), string(updated))
}
func TestWSResponseCreate_NoServiceTierUntouched(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
frame := []byte(`{"type":"response.create","model":"gpt-5.5","input":[]}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated), "no service_tier present must result in zero mutation")
}
func TestWSResponseCreate_NonResponseCreateFrameUntouched(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{"*"},
FallbackAction: BetaPolicyActionFilter,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// response.cancel happens to carry a service_tier-shaped field — must not be touched.
frame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated))
}
// TestWSResponseCreate_EmptyTypeFrameUntouched is the A1 regression: the
// helper used to treat empty type as response.create, which risked stripping
// fields from malformed / unknown client events. After the A1 fix only a
// strict "response.create" match triggers policy.
func TestWSResponseCreate_EmptyTypeFrameUntouched(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{"*"},
FallbackAction: BetaPolicyActionFilter,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Frame with no "type" field: must pass through completely unchanged
// even with a service_tier-shaped field present.
frame := []byte(`{"service_tier":"priority","model":"gpt-5.5"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated), "empty type must NOT be policy-checked — Realtime spec requires type, malformed frames are passed through")
// Explicit empty string also passes through.
frame = []byte(`{"type":"","service_tier":"priority","model":"gpt-5.5"}`)
updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated))
}
// TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode is the B1
// regression: the rendered Realtime error event must carry a non-empty
// event_id (so clients can correlate the rejection) and a stable error.code
// ("policy_violation"). The HTTP-side equivalent is the 403 permission_error
// JSON body emitted by writeOpenAIFastPolicyBlockedResponse.
func TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode(t *testing.T) {
bytes := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "blocked because reasons"})
require.NotNil(t, bytes)
require.Equal(t, "error", gjson.GetBytes(bytes, "type").String())
require.Equal(t, "invalid_request_error", gjson.GetBytes(bytes, "error.type").String())
require.Equal(t, "policy_violation", gjson.GetBytes(bytes, "error.code").String())
require.Equal(t, "blocked because reasons", gjson.GetBytes(bytes, "error.message").String())
eventID := gjson.GetBytes(bytes, "event_id").String()
require.NotEmpty(t, eventID, "event_id must be present so clients can correlate the rejection in their logs")
require.True(t, strings.HasPrefix(eventID, "evt_"), "event_id should follow the evt_<rand> Realtime convention; got %q", eventID)
// Sanity check: two consecutive events get distinct IDs.
other := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "second"})
otherID := gjson.GetBytes(other, "event_id").String()
require.NotEqual(t, eventID, otherID, "event_id must be random per-event")
}
// TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe ensures the helper returns
// nil for a nil error (defensive guard for callers that always invoke it).
func TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe(t *testing.T) {
require.Nil(t, buildOpenAIFastPolicyBlockedWSEvent(nil))
}
// --- D5: passthrough wrapper FrameConn — capturedSessionModel fallback ---
// fakePassthroughFrameConn replays a fixed sequence of client frames into the
// policy-enforcing wrapper, then returns io.EOF. Captures all Write attempts
// for write-side assertions (none expected in the D5 test, since the wrapper
// only filters reads).
type fakePassthroughFrameConn struct {
reads [][]byte
idx int
writes [][]byte
closeOnce bool
}
func (f *fakePassthroughFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if f.idx >= len(f.reads) {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
payload := f.reads[f.idx]
f.idx++
return coderws.MessageText, payload, nil
}
func (f *fakePassthroughFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
cp := append([]byte(nil), payload...)
f.writes = append(f.writes, cp)
return nil
}
func (f *fakePassthroughFrameConn) Close() error {
f.closeOnce = true
return nil
}
// gpt55WhitelistFastPolicy 返回一份强制带 model whitelist 的策略,用于
// 验证 capturedSessionModel fallback 的语义(默认策略 whitelist 为空时
// fallback 路径无法被观察到)。
func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings {
return &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{"gpt-5.5", "gpt-5.5*"},
FallbackAction: BetaPolicyActionPass,
}},
}
}
// TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel is
// the D5 regression: in passthrough mode a follow-up response.create frame
// without a "model" field must still hit the policy via the session-level
// model captured from the first frame. Without the fallback an empty model
// would miss a model whitelist and silently leak service_tier=priority
// through to the upstream.
func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) {
// 此处特意使用带 whitelist 的策略,以便观察 capturedSessionModel
// fallback 是否生效(默认策略 whitelist 为空,fallback 与否结果一致,
// 不能用来覆盖此回归)。
svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Simulate the passthrough adapter capturing model from the first frame.
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
require.Equal(t, "gpt-5.5", capturedSessionModel)
// Follow-up frame deliberately omits "model" — Realtime allows this.
followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
inner := &fakePassthroughFrameConn{
reads: [][]byte{followupFrame},
}
wrapper := &openAIWSPolicyEnforcingFrameConn{
inner: inner,
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
},
}
// Read the follow-up frame through the wrapper. The policy MUST still
// trigger filter (gpt-5.5 + priority → filter), so the service_tier
// field is gone by the time the relay sees it.
_, payload, err := wrapper.ReadFrame(context.Background())
require.NoError(t, err)
require.NotContains(t, string(payload), `"service_tier"`,
"D5 regression: empty model on follow-up frame must fall back to capturedSessionModel; whitelist policy filters service_tier=priority for gpt-5.5")
require.Equal(t, "response.create", gjson.GetBytes(payload, "type").String())
}
// TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses pins the
// inverse: when the wrapper has NO capturedSessionModel fallback (model is
// empty per-frame and no fallback is wired up), the policy fails to match
// the model whitelist and the frame leaks through unchanged. This documents
// exactly the leak the D5 fix prevents.
func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing.T) {
// 同样使用带 whitelist 的策略以观察 leak。
svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
inner := &fakePassthroughFrameConn{reads: [][]byte{followupFrame}}
wrapper := &openAIWSPolicyEnforcingFrameConn{
inner: inner,
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
// NO fallback — emulate the pre-fix behavior.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
},
}
_, payload, err := wrapper.ReadFrame(context.Background())
require.NoError(t, err)
// Pre-fix: empty model misses ["gpt-5.5","gpt-5.5*"] whitelist → fallback=pass → service_tier kept.
require.Contains(t, string(payload), `"service_tier"`,
"sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing")
}
// --- Ingress end-to-end test (filter path) ---
// TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the
// real ProxyResponsesWebSocketFromClient ingress session pipeline against a
// captureConn upstream and asserts that a client frame with service_tier=fast
// is normalized + filtered out before being written upstream. This is the
// integration flavour of TestWSResponseCreate_FilterStripsServiceTier.
func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_ws_filter_1","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings())
require.NoError(t, err)
repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
settingService: NewSettingService(repo, cfg),
}
account := &Account{
ID: 901,
Name: "openai-ws-filter",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{"api_key": "sk-test"},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() { _ = conn.CloseNow() }()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
_, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"fast"}`)))
cancelWrite()
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}
require.Len(t, captureConn.writes, 1, "上游应只收到一条 response.create")
upstream := captureConn.writes[0]
_, hasServiceTier := upstream["service_tier"]
require.False(t, hasServiceTier, "上游收到的 response.create 不应包含 service_tier 字段(已被 fast policy filter 删除)")
require.Equal(t, "response.create", upstream["type"])
require.Equal(t, "gpt-5.5", upstream["model"])
}
// TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream is the
// integration flavour of TestWSResponseCreate_BlockReturnsTypedError. It
// asserts that with a custom block rule, the client receives a Realtime-style
// error event AND the upstream FrameConn never receives the offending frame.
func TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
// No events queued; the upstream should never get written to anyway.
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
blockSettings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "ws priority blocked for testing",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
raw, err := json.Marshal(blockSettings)
require.NoError(t, err)
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
settingService: NewSettingService(repo, cfg),
}
account := &Account{
ID: 902,
Name: "openai-ws-block",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{"api_key": "sk-test"},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() { _ = conn.CloseNow() }()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
_, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
proxyErr := svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
// Mirror the production handler (openai_gateway_handler.go:1325-1328):
// when the proxy returns an OpenAIWSClientCloseError, surface its
// status code to the client via a graceful close handshake. Without
// this the deferred CloseNow() above would tear down the TCP
// connection without sending a close frame, and the C3 timing
// assertion (next read returns CloseStatus=1008) would see EOF
// instead.
var closeErr *OpenAIWSClientCloseError
if errors.As(proxyErr, &closeErr) {
_ = conn.Close(closeErr.StatusCode(), closeErr.Reason())
}
serverErrCh <- proxyErr
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"priority"}`)))
cancelWrite()
// C3 timing assertion: the FIRST frame the client reads must be the
// error event — not a close frame. coder/websocket@v1.8.14 Conn.Write is
// synchronous (writeFrame Flushes the bufio writer at write.go:307-311
// before returning) and the close handshake re-acquires the same
// writeFrameMu, so this ordering is enforced by the library itself; this
// assertion guards against future refactors that might break it.
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr, "first read must succeed and return the error event before any close frame")
require.Equal(t, "error", gjson.GetBytes(event, "type").String())
require.Equal(t, "invalid_request_error", gjson.GetBytes(event, "error.type").String())
// B1 regression: event_id + error.code must be populated.
require.Equal(t, "policy_violation", gjson.GetBytes(event, "error.code").String())
require.NotEmpty(t, gjson.GetBytes(event, "event_id").String(), "event_id must be present so clients can correlate")
require.Contains(t, gjson.GetBytes(event, "error.message").String(), "ws priority blocked for testing")
// Next read must surface the close frame (as a CloseError). This
// asserts the [error event, close] ordering — i.e. the close did NOT
// race ahead of the data frame.
readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second)
_, _, secondReadErr := clientConn.Read(readCtx2)
cancelRead2()
require.Error(t, secondReadErr, "after the error event the connection must surface a close")
require.Equal(t, coderws.StatusPolicyViolation, coderws.CloseStatus(secondReadErr),
"close status must be PolicyViolation; got %v", secondReadErr)
select {
case serverErr := <-serverErrCh:
// Server returns an OpenAIWSClientCloseError — handler closes the WS;
// here we just assert it surfaced as the typed close error.
require.Error(t, serverErr)
var closeErr *OpenAIWSClientCloseError
require.True(t, errors.As(serverErr, &closeErr), "block 应返回 OpenAIWSClientCloseError,得到 %T: %v", serverErr, serverErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress 关闭超时")
}
// Critical: the offending frame must NEVER reach the upstream.
// captureDialer.DialCount may legitimately be 0 or 1 depending on whether
// the lease was acquired before policy fired; either way, no writes.
require.Empty(t, captureConn.writes, "block 命中后上游不应收到 response.create")
}
// --- HTTP-side gap-filling tests (already covered by existing tests but
// requested to be split out explicitly) ---
// TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream confirms that
// applyOpenAIFastPolicyToBody surfaces a *OpenAIFastBlockedError when the rule
// action is "block", and that the body is left untouched. The caller (chat
// completions / messages handlers) inspects this typed error and skips the
// upstream HTTP call entirely — see openai_gateway_chat_completions.go:175 and
// openai_gateway_messages.go:149.
func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "priority blocked",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","input":[]}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.Error(t, err)
var blocked *OpenAIFastBlockedError
require.True(t, errors.As(err, &blocked), "block must surface as typed error so caller can skip upstream HTTP request")
require.Equal(t, "priority blocked", blocked.Message)
require.Equal(t, string(body), string(updated), "block must not mutate body")
}
// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies
// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode
// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60)
// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has
// no service_tier. We exercise the same internal pipeline (Anthropic→Responses
// + BetaFastMode + policy) without spinning up a real upstream HTTP server.
func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Step 1: parse Anthropic request (mirrors openai_gateway_messages.go:38-50).
anthropicBody := []byte(`{"model":"gpt-5.5","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)
var anthropicReq apicompat.AnthropicRequest
require.NoError(t, json.Unmarshal(anthropicBody, &anthropicReq))
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
require.NoError(t, err)
// Step 2: BetaFastMode header → service_tier="priority" (mirrors line 58-61).
headers := http.Header{}
headers.Set("anthropic-beta", claude.BetaFastMode)
require.True(t, containsBetaToken(headers.Get("anthropic-beta"), claude.BetaFastMode))
responsesReq.ServiceTier = "priority"
responsesReq.Model = "gpt-5.5"
// Step 3: marshal & apply fast policy (mirrors line 78 + 149).
responsesBody, err := json.Marshal(responsesReq)
require.NoError(t, err)
require.Equal(t, "priority", gjson.GetBytes(responsesBody, "service_tier").String(), "前置:beta 翻译应当注入 priority")
upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody)
require.NoError(t, policyErr)
// Step 4: assert that policy filtered the field before the upstream HTTP request.
require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier")
}
// --- Fix1: passthrough capturedSessionModel must follow session.update ---
// TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel covers the
// fix1 bypass: client opens with a whitelist-miss model (gpt-4o → pass under
// gpt-5.5 whitelist), rotates to gpt-5.5 via session.update, then sends
// response.create without "model". Without the session.update sniffing the
// follow-up frame would fall back to the stale gpt-4o capture and pass — the
// fix updates capturedSessionModel from session.* events so the fallback now
// resolves to gpt-5.5 and the policy filters service_tier.
func TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Frame 1: response.create with whitelist-miss model — under default
// rule fallback=pass, service_tier stays.
first := []byte(`{"type":"response.create","model":"gpt-4o","service_tier":"priority"}`)
// Frame 2: session.update rotates the session model to gpt-5.5.
rotate := []byte(`{"type":"session.update","session":{"model":"gpt-5.5"}}`)
// Frame 3: response.create WITHOUT model — must inherit gpt-5.5.
followup := []byte(`{"type":"response.create","service_tier":"priority"}`)
inner := &fakePassthroughFrameConn{reads: [][]byte{first, rotate, followup}}
// Replicate the production wiring in openai_ws_v2_passthrough_adapter.go
// so capturedSessionModel state is shared across frames.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, first)
require.Equal(t, "gpt-4o", capturedSessionModel)
wrapper := &openAIWSPolicyEnforcingFrameConn{
inner: inner,
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
},
}
// Frame 1: gpt-4o miss whitelist → pass (service_tier preserved).
_, payload1, err := wrapper.ReadFrame(context.Background())
require.NoError(t, err)
require.Contains(t, string(payload1), `"service_tier"`, "frame1: gpt-4o miss whitelist → pass keeps service_tier")
// Frame 2: session.update — not response.create, untouched, but its
// side effect updates capturedSessionModel to gpt-5.5.
_, payload2, err := wrapper.ReadFrame(context.Background())
require.NoError(t, err)
require.Equal(t, string(rotate), string(payload2), "session.update frame is forwarded verbatim")
require.Equal(t, "gpt-5.5", capturedSessionModel, "fix1: session.update must rotate capturedSessionModel")
// Frame 3: empty model + new captured gpt-5.5 → matches whitelist → filter.
_, payload3, err := wrapper.ReadFrame(context.Background())
require.NoError(t, err)
require.NotContains(t, string(payload3), `"service_tier"`,
"fix1: post-rotate response.create without model must use refreshed capturedSessionModel and trigger filter")
}
// TestPolicyModelFromSessionFrame_OnlySessionUpdate covers the negative
// branches of openAIWSPassthroughPolicyModelFromSessionFrame: only
// client→upstream session.update frames rotate the captured model;
// server→client events (session.created) and unrelated frames must not.
func TestPolicyModelFromSessionFrame_OnlySessionUpdate(t *testing.T) {
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// session.created is a server→client event in the OpenAI Realtime
// protocol — clients never send it, so this filter (which only runs on
// the client→upstream direction) must ignore it even if it appears.
created := []byte(`{"type":"session.created","session":{"model":"gpt-5.5"}}`)
require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, created))
// Non-session.* frames must NOT trigger rotation.
notSession := []byte(`{"type":"response.create","session":{"model":"gpt-9"}}`)
require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, notSession))
// Missing session.model returns empty — caller keeps the old captured value.
noModel := []byte(`{"type":"session.update","session":{"voice":"alloy"}}`)
require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, noModel))
}
// --- Fix2: native /responses normalize "fast" → "priority" on pass ---
// TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias is the fix2
// regression. Before the fix, when action=pass, applyOpenAIFastPolicyToBody
// returned the body unchanged so a raw "fast" alias would leak to the
// upstream OpenAI API (which does not accept "fast"). The fix normalizes
// "fast" → "priority" on pass too.
func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) {
// Use a policy that deliberately misses gpt-4 so the action is pass.
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// gpt-4 + "fast" → fallback pass. Body must be rewritten to "priority".
body := []byte(`{"model":"gpt-4","service_tier":"fast"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
require.NoError(t, err)
require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(),
"fix2: pass action must still normalize 'fast' → 'priority' so upstream OpenAI accepts the slug")
// Already-canonical "priority" on pass: zero mutation (byte-equal).
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
require.NoError(t, err)
require.Equal(t, string(body), string(updated))
// Mixed-case alias → normalized.
body = []byte(`{"model":"gpt-4","service_tier":" Fast "}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
require.NoError(t, err)
require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String())
// Unrecognized tier → still no-op (not normalized, since normTier == "").
body = []byte(`{"model":"gpt-4","service_tier":"turbo"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
require.NoError(t, err)
require.Equal(t, string(body), string(updated))
}
// --- Fix3: passthrough billing must reflect post-filter service_tier ---
// TestPassthroughBilling_PostFilterServiceTier is the fix3 regression. The
// passthrough adapter (openai_ws_v2_passthrough_adapter.go) now extracts
// requestServiceTier from firstClientMessage AFTER applyOpenAIFastPolicy
// has rewritten it, so a filter hit causes billing to report nil (default
// tier) instead of the user-requested "priority". This test pins the
// contract those two helpers must uphold for the adapter's billing path.
func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
// Pre-filter sanity: extracting from the raw frame would (incorrectly,
// pre-fix) report "priority" — this is the very thing the adapter
// must NOT do anymore.
pre := extractOpenAIServiceTierFromBody(raw)
require.NotNil(t, pre)
require.Equal(t, "priority", *pre,
"sanity: raw first frame carries priority that pre-fix billing would have reported")
// Apply policy filter (default rule: gpt-5.5 + priority → filter).
filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw)
require.NoError(t, err)
require.Nil(t, blocked)
require.NotContains(t, string(filtered), `"service_tier"`)
// Post-filter: extracting from the rewritten frame returns nil. This
// is the value the adapter now passes to OpenAIForwardResult.ServiceTier,
// so billing records "default" instead of "priority".
post := extractOpenAIServiceTierFromBody(filtered)
require.Nil(t, post, "fix3: post-filter extraction must return nil so passthrough billing reports default tier instead of the requested priority")
// And the byte-level invariant the adapter relies on: filtering an
// already-filtered frame is a no-op (idempotent), so re-running the
// policy doesn't accidentally re-introduce the field.
again, blocked2, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", filtered)
require.NoError(t, err)
require.Nil(t, blocked2)
require.Equal(t, string(filtered), string(again),
"policy is idempotent: filtering an already-filtered frame leaves bytes unchanged")
}
// TestApplyOpenAIFastPolicyToBody_NonStringServiceTier covers the test gap
// flagged in the review: when a client sends service_tier as a non-string
// (number, null, object, etc.) the policy must NOT panic and must NOT
// pretend the field was filtered. Behavior: skip policy entirely (treat as
// "no usable tier"), forward body unchanged. This mirrors the HTTP entry's
// type-assertion `reqBody["service_tier"].(string); ok` guard.
func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Number — gjson .String() coerces to "1" which is not a recognized
// tier alias; normalize returns "" → policy no-ops.
cases := [][]byte{
[]byte(`{"model":"gpt-5.5","service_tier":1}`),
[]byte(`{"model":"gpt-5.5","service_tier":null}`),
[]byte(`{"model":"gpt-5.5","service_tier":{"nested":"priority"}}`),
[]byte(`{"model":"gpt-5.5","service_tier":["priority"]}`),
[]byte(`{"model":"gpt-5.5","service_tier":true}`),
}
for _, body := range cases {
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err, "non-string service_tier must not error: %s", string(body))
require.Equal(t, string(body), string(updated),
"non-string service_tier must pass through unchanged: %s", string(body))
}
// Same guard for the WS response.create entry.
for _, body := range cases {
frame := body
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err, "non-string service_tier ws frame must not error: %s", string(frame))
require.Nil(t, blocked, "non-string service_tier must not trigger block: %s", string(frame))
require.Equal(t, string(frame), string(updated),
"non-string service_tier ws frame must pass through unchanged: %s", string(frame))
}
}
// TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames covers the
// multi-turn passthrough billing regression: OpenAI Realtime / Responses WS
// allows the client to ship a different service_tier on each response.create
// frame (per-response field, see codex-rs/core/src/client.rs
// build_responses_request which re-fills the field on every request). Before
// the fix the adapter only captured service_tier from firstClientMessage so
// turn 2/3 billing was wrong. After the fix the filter closure refreshes an
// atomic.Pointer[string] on every successful response.create frame.
//
// This test pins the four legs of the semantic contract:
// - turn 1: service_tier=priority hits the default whitelist filter, so
// after filter the upstream sees no tier → billing is nil.
// - turn 2: service_tier=flex passes (default rule targets priority only),
// billing should now reflect "flex".
// - turn 3: response.create without any service_tier — the upstream will
// treat it as default; we choose to mirror that and overwrite billing
// to nil rather than carry over "flex" from turn 2.
// - non-response.create frame (response.cancel here) carrying a stray
// service_tier-shaped field must NOT clobber the billing pointer.
func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go
// proxyResponsesWebSocketV2Passthrough) so this test fails if the
// production code drops the per-frame Store.
var requestServiceTierPtr atomic.Pointer[string]
capturedSessionModel := ""
filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
}
return out, blocked, policyErr
}
// First-frame initialization mirrors the adapter: extract from the
// post-filter payload so a filter-on-first-frame zeroes billing too.
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", firstFrame)
require.NoError(t, firstErr)
require.Nil(t, firstBlocked)
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstOut))
capturedSessionModel = openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
require.Nil(t, requestServiceTierPtr.Load(),
"turn 1: filter strips service_tier=priority, billing must reflect upstream-actual nil tier")
// Turn 2: client switches to flex, should pass and update billing.
turn2 := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
out2, blocked2, err2 := filter(coderws.MessageText, turn2)
require.NoError(t, err2)
require.Nil(t, blocked2)
require.Equal(t, "flex", gjson.GetBytes(out2, "service_tier").String(), "turn 2: flex must pass to upstream untouched")
tier2 := requestServiceTierPtr.Load()
require.NotNil(t, tier2, "turn 2: billing must update to reflect flex")
require.Equal(t, "flex", *tier2)
// A non-response.create frame with a stray service_tier-shaped field
// must NOT overwrite the billing pointer (those frames don't carry
// per-response service_tier in the Realtime spec).
cancelFrame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
_, blockedCancel, errCancel := filter(coderws.MessageText, cancelFrame)
require.NoError(t, errCancel)
require.Nil(t, blockedCancel)
tierAfterCancel := requestServiceTierPtr.Load()
require.NotNil(t, tierAfterCancel, "response.cancel must not clobber billing tier to nil")
require.Equal(t, "flex", *tierAfterCancel,
"non-response.create frames must not update billing tier even if they carry a service_tier-shaped field")
// Turn 3: response.create without any service_tier. We deliberately
// overwrite billing back to nil so it tracks what the upstream actually
// sees on this turn (default tier).
turn3 := []byte(`{"type":"response.create","model":"gpt-5.5"}`)
out3, blocked3, err3 := filter(coderws.MessageText, turn3)
require.NoError(t, err3)
require.Nil(t, blocked3)
require.Equal(t, string(turn3), string(out3), "turn 3 has no service_tier — filter must not mutate")
require.Nil(t, requestServiceTierPtr.Load(),
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
}
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
// "block keeps previous" semantic: when policy returns block on a
// response.create frame, that frame is never sent upstream, so billing tier
// must keep the previous turn's value rather than getting silently zeroed.
func TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier(t *testing.T) {
blockSettings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "blocked",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, blockSettings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
var requestServiceTierPtr atomic.Pointer[string]
flexValue := "flex"
requestServiceTierPtr.Store(&flexValue) // simulate prior turn billed as flex
filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", payload)
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
}
return out, blocked, policyErr
}
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
_, blocked, err := filter(coderws.MessageText, frame)
require.NoError(t, err)
require.NotNil(t, blocked, "policy must block this frame")
tier := requestServiceTierPtr.Load()
require.NotNil(t, tier, "blocked frame must not clobber prior billing tier to nil")
require.Equal(t, "flex", *tier,
"blocked frame is never sent upstream; billing must retain the previous turn's tier")
}
......@@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
}
// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
......
......@@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "flex", req.ServiceTier)
// OpenAI 官方合法 tier 应被透传保留。
req.ServiceTier = "auto"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "auto", req.ServiceTier)
req.ServiceTier = "default"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "default", req.ServiceTier)
req.ServiceTier = "scale"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "scale", req.ServiceTier)
// 真未知值仍被剥离。
req.ServiceTier = "turbo"
normalizeResponsesRequestServiceTier(req)
require.Empty(t, req.ServiceTier)
}
......@@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Equal(t, "flex", tier)
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
// OpenAI 官方 tier 直接保留在 body 中(透传上游)。
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
require.NoError(t, err)
require.Equal(t, "auto", tier)
require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
require.NoError(t, err)
require.Equal(t, "default", tier)
require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
require.NoError(t, err)
require.Equal(t, "scale", tier)
require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
// 真未知值才会被删除。
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
require.NoError(t, err)
require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}
......@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
// on the body-level service_tier field (priority/flex).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
......
......@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
......@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *got)
})
t.Run("default ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("default"))
t.Run("openai official tiers preserved", func(t *testing.T) {
// OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
// 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex,
// 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。
for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
got := normalizeOpenAIServiceTier(tier)
require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
require.Equal(t, tier, *got)
}
})
t.Run("invalid ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
})
}
func TestExtractOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
require.Nil(t, extractOpenAIServiceTier(nil))
}
......@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}
......
......@@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
resolver *ModelPricingResolver
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
settingService *SettingService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
......@@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
resolver *ModelPricingResolver,
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
settingService *SettingService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
......@@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
resolver: resolver,
channelService: channelService,
balanceNotifyService: balanceNotifyService,
settingService: settingService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
......@@ -2310,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
}
// Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤):
// 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略
// 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽
// fast 时在此生效。
//
// 注意:
// 1. 此处统一使用 upstreamModel(已经过 GetMappedModel +
// normalizeOpenAIModelForUpstream + Codex OAuth normalize),与
// chat-completions / messages 入口保持一致,避免不同入口因为模型
// 维度不同而出现 whitelist 命中差异。
// 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body,
// 否则 native /responses 入口透传 "fast" 给上游会被拒。chat-
// completions 入口由 normalizeResponsesBodyServiceTier 完成同一
// 行为,这里手工实现等效逻辑。
if rawTier, ok := reqBody["service_tier"].(string); ok {
if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
switch action {
case BetaPolicyActionBlock:
msg := errMsg
if msg == "" {
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
}
blocked := &OpenAIFastBlockedError{Message: msg}
writeOpenAIFastPolicyBlockedResponse(c, blocked)
return nil, blocked
case BetaPolicyActionFilter:
delete(reqBody, "service_tier")
bodyModified = true
disablePatch()
default:
// pass:若客户端传的是别名 "fast",归一化为 "priority"
// 后写回 body,确保上游收到的是其能识别的规范值。
if normTier != rawTier {
reqBody["service_tier"] = normTier
bodyModified = true
markPatchSet("service_tier", normTier)
}
}
}
}
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
......@@ -2758,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
body = sanitizedBody
}
// Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
// 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 +
// OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。
// 这样可以与 chat-completions / messages / native /responses 入口的
// upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有
// model 字段时退回 reqModel。
policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
if policyModel == "" {
policyModel = reqModel
}
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeOpenAIFastPolicyBlockedResponse(c, blocked)
}
return nil, policyErr
}
body = updatedBody
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
......@@ -5590,14 +5655,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
if value == "fast" {
value = "priority"
}
// 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。
// 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs),
// 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。
// 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。
switch value {
case "priority", "flex":
case "priority", "flex", "auto", "default", "scale":
return &value
default:
return nil
}
}
// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
// policy (action=block). Mirrors BetaBlockedError on the Claude side.
type OpenAIFastBlockedError struct {
Message string
}
func (e *OpenAIFastBlockedError) Error() string { return e.Message }
// evaluateOpenAIFastPolicy returns the action and error message that should be
// applied for a request with the given account/model/service_tier. When the
// policy service is unavailable or no rule matches, it returns
// (BetaPolicyActionPass, "") so callers can short-circuit safely.
//
// Matching rules:
// - Scope filters by account type (all / oauth / apikey / bedrock)
// - ServiceTier must be empty (= any), "all", or equal the normalized tier
// - ModelWhitelist narrows the rule to specific models; FallbackAction
// handles the non-matching case (default: pass)
//
// 与 Claude BetaPolicy 的差异(保留首条匹配 short-circuit):
// - BetaPolicy 处理的是 anthropic-beta header 中的 token 集合,不同
// 规则可能针对不同 token,filter 需要累加成 set;block 则 first-match。
// - OpenAI fast policy 操作的是单个字段 service_tier:filter 即删字段,
// 没有可累加的对象。一次请求只携带一个 service_tier,规则的 tier
// 维度天然互斥;同一 (scope, tier) 下若多条规则的 model whitelist
// 发生重叠,admin 可通过规则顺序明确意图。因此采用 first-match 而
// 非 BetaPolicy 那样的"block 覆盖 filter 覆盖 pass"语义。
func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
if s == nil || s.settingService == nil {
return BetaPolicyActionPass, ""
}
tier := strings.ToLower(strings.TrimSpace(serviceTier))
if tier == "" {
return BetaPolicyActionPass, ""
}
settings := openAIFastPolicySettingsFromContext(ctx)
if settings == nil {
fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
if err != nil || fetched == nil {
return BetaPolicyActionPass, ""
}
settings = fetched
}
return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
}
// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
// the settingService on every frame. See WSSession entry and
// openAIFastPolicySettingsFromContext for the caching glue.
func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
if settings == nil {
return BetaPolicyActionPass, ""
}
isOAuth := account != nil && account.IsOAuth()
isBedrock := account != nil && account.IsBedrock()
for _, rule := range settings.Rules {
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
continue
}
eff := BetaPolicyRule{
Action: rule.Action,
ErrorMessage: rule.ErrorMessage,
ModelWhitelist: rule.ModelWhitelist,
FallbackAction: rule.FallbackAction,
FallbackErrorMessage: rule.FallbackErrorMessage,
}
return resolveRuleAction(eff, model)
}
return BetaPolicyActionPass, ""
}
// openAIFastPolicyCtxKey 是 context 中预取的 OpenAIFastPolicySettings 缓存
// 键,仅用于 WebSocket 长会话内多帧复用同一份策略快照,避免每帧 DB 命中。
//
// Trade-off:策略变更不会影响当前 WS session(只影响新 session)。这是
// 有意为之 —— 对长会话来说,"策略一致性"比"立刻生效"更重要,且 Claude
// BetaPolicy 的 gin.Context 缓存也是同样取舍。需要 hot-reload 时管理员
// 可以通过踢断 session 强制刷新。
type openAIFastPolicyCtxKeyType struct{}
var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
// withOpenAIFastPolicyContext 将一份 settings 快照绑定到 context,供该 ctx
// 衍生 goroutine 中的 evaluateOpenAIFastPolicy 复用。
func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
if ctx == nil || settings == nil {
return ctx
}
return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
}
func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
if ctx == nil {
return nil
}
if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
return v
}
return nil
}
// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
// body. When action=filter it removes the service_tier field; when
// action=block it returns (body, *OpenAIFastBlockedError). On pass it
// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
// rewriting the body so the upstream receives a slug it recognizes.
//
// Rationale for normalize-on-pass: chat-completions / messages 入口在调用本
// 函数之前已经通过 normalizeResponsesBodyServiceTier 把 service_tier 归一化
// 到了上游可识别值;passthrough(OpenAI 自动透传) / native /responses 等
// 入口没有这一前置步骤,pass 路径下若不在此处归一化,"fast" 就会被原样
// 透传到 OpenAI 上游导致 400/拒绝。把归一化收敛到本函数,所有入口行为一致。
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
if len(body) == 0 {
return body, nil
}
rawTier := gjson.GetBytes(body, "service_tier").String()
if rawTier == "" {
return body, nil
}
normTier := normalizedOpenAIServiceTierValue(rawTier)
if normTier == "" {
return body, nil
}
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
switch action {
case BetaPolicyActionBlock:
msg := errMsg
if msg == "" {
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
}
return body, &OpenAIFastBlockedError{Message: msg}
case BetaPolicyActionFilter:
trimmed, err := sjson.DeleteBytes(body, "service_tier")
if err != nil {
return body, fmt.Errorf("strip service_tier from body: %w", err)
}
return trimmed, nil
default:
// pass:把别名(如 "fast")写回为规范值("priority")。
if normTier == rawTier {
return body, nil
}
updated, err := sjson.SetBytes(body, "service_tier", normTier)
if err != nil {
return body, fmt.Errorf("normalize service_tier on pass: %w", err)
}
return updated, nil
}
}
// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
// request blocked by the OpenAI fast policy.
func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
if c == nil || err == nil {
return
}
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": err.Message,
},
})
}
// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
// against a single client→upstream WebSocket frame whose top-level
// "type"=="response.create". It mirrors the HTTP-side
// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
// WS payload:
//
// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
// - filter: returns a copy with top-level service_tier removed
// - block: returns (frame, *OpenAIFastBlockedError)
//
// Only frames whose "type" field strictly equals "response.create" are
// inspected/mutated. Any other frame type — including the empty string —
// passes through untouched. The OpenAI Realtime client-event spec requires
// "type" to be set, so an empty type is treated as a malformed frame we do
// not police; the upstream is the source of truth for rejecting it.
//
// service_tier lives at the top level of response.create — same as the
// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
// to inspect / strip the top-level field; there is no nested form in the
// schema today.
//
// The caller is responsible for choosing the upstream model passed in —
// this helper does not re-derive it.
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
ctx context.Context,
account *Account,
model string,
frame []byte,
) ([]byte, *OpenAIFastBlockedError, error) {
if len(frame) == 0 {
return frame, nil, nil
}
if !gjson.ValidBytes(frame) {
return frame, nil, nil
}
frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
// Strict match: only response.create is policy-checked. Empty / other
// types pass through untouched so we never accidentally strip fields
// from response.cancel, conversation.item.create, or any future
// client-event the spec adds. The Realtime spec requires "type" on
// every client event, so an empty type is malformed input — let the
// upstream reject it rather than guessing at our layer.
if frameType != "response.create" {
return frame, nil, nil
}
rawTier := gjson.GetBytes(frame, "service_tier").String()
if rawTier == "" {
return frame, nil, nil
}
normTier := normalizedOpenAIServiceTierValue(rawTier)
if normTier == "" {
return frame, nil, nil
}
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
switch action {
case BetaPolicyActionBlock:
msg := errMsg
if msg == "" {
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
}
return frame, &OpenAIFastBlockedError{Message: msg}, nil
case BetaPolicyActionFilter:
trimmed, err := sjson.DeleteBytes(frame, "service_tier")
if err != nil {
return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
}
return trimmed, nil, nil
default:
return frame, nil, nil
}
}
// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
// server-emitted error event. Matches the loose "evt_<rand>" convention used
// by upstream Realtime servers; the exact value is not load-bearing and is
// only required for client-side log correlation. We reuse the existing
// google/uuid dependency rather than pulling a new one.
func newOpenAIFastPolicyWSEventID() string {
id, err := uuid.NewRandom()
if err != nil {
// Extremely unlikely; fall back to a fixed prefix so the field is
// still non-empty and the schema stays self-consistent.
return "evt_openai_fast_policy"
}
// Strip dashes so it visually matches "evt_<hex>" rather than UUID v4
// canonical form, mirroring what real Realtime traces look like.
return "evt_" + strings.ReplaceAll(id.String(), "-", "")
}
// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
// style "error" event payload for a request blocked by the OpenAI fast
// policy. The shape mirrors Realtime error events as observed in upstream
// traces and per the spec's server "error" event:
//
// {
// "event_id": "evt_<random>",
// "type": "error",
// "error": {
// "type": "invalid_request_error",
// "code": "policy_violation",
// "message": "..."
// }
// }
//
// event_id lets clients correlate the rejection in their logs; "code" gives
// programmatic clients a stable identifier (HTTP-side equivalent is the
// 403 permission_error JSON body).
func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
if err == nil {
return nil
}
eventID := newOpenAIFastPolicyWSEventID()
payload, mErr := json.Marshal(map[string]any{
"event_id": eventID,
"type": "error",
"error": map[string]any{
"type": "invalid_request_error",
"code": "policy_violation",
"message": err.Message,
},
})
if mErr != nil {
// Fallback to a minimal hand-rolled payload; Marshal of the literal
// shape above should never fail in practice.
return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
}
return payload
}
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
return body, false, nil
......
......@@ -2366,6 +2366,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return errors.New("token is empty")
}
// 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session
// 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧
// 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。
if s.settingService != nil {
if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
ctx = withOpenAIFastPolicyContext(ctx, settings)
}
}
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeCtxPool
......@@ -2524,6 +2533,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
normalized = next
}
// Apply OpenAI Fast Policy on the response.create frame using the same
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
// single integration point for all WS ingress turns (first + follow-up
// frames flow through here).
//
// Model fallback: parseClientPayload above rejects any frame whose
// "model" field is missing (line ~2493-2500), so by the time we
// reach this point upstreamModel is always derived from a non-empty
// per-frame model. The capturedSessionModel fallback used in the
// passthrough adapter is therefore not needed in this path.
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
if policyErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
}
if blocked != nil {
// Send a Realtime-style error event to the client first, then
// signal the handler to close the connection with PolicyViolation.
// We intentionally do NOT forward this frame upstream.
//
// coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
// the underlying bufio writer before returning (write.go:42 →
// 307-311), and the subsequent close handshake re-acquires the
// same writeFrameMu, so the error event is guaranteed to reach
// the kernel send buffer before any close frame is queued.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
}
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
blocked.Message,
blocked,
)
}
normalized = policyApplied
return openAIWSClientPayload{
payloadRaw: normalized,
rawForHash: trimmed,
......
......@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
......@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame through the OpenAI Fast Policy. It is the
// passthrough-relay equivalent of the parseClientPayload integration in the
// ingress session path. filter returns:
// - newPayload, nil, nil: forward the (possibly mutated) payload
// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
// event via onBlock and surfaces a transport-level error so the relay
// stops reading from the client.
// - _, _, err: a transport error other than block.
type openAIWSPolicyEnforcingFrameConn struct {
inner openaiwsv2.FrameConn
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
onBlock func(blocked *OpenAIFastBlockedError)
}
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.inner == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
msgType, payload, err := c.inner.ReadFrame(ctx)
if err != nil {
return msgType, payload, err
}
if c.filter == nil {
return msgType, payload, nil
}
updated, blocked, filterErr := c.filter(msgType, payload)
if filterErr != nil {
return msgType, payload, filterErr
}
if blocked != nil {
if c.onBlock != nil {
c.onBlock(blocked)
}
return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
return msgType, updated, nil
}
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.inner == nil {
return errOpenAIWSConnClosed
}
return c.inner.WriteFrame(ctx, msgType, payload)
}
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
if c == nil || c.inner == nil {
return nil
}
return c.inner.Close()
}
// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
// model name that should be passed to evaluateOpenAIFastPolicy for a single
// passthrough WS frame. Mirrors the HTTP-side normalization
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
// matches model whitelists identically.
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
if account == nil || len(payload) == 0 {
return ""
}
original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if original == "" {
return ""
}
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}
// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
// derived from a session.update frame's session.model field. Returns "" when
// the frame is not a session.update event or carries no session.model. Used
// by the per-frame policy filter (client→upstream direction) to keep
// capturedSessionModel in sync with the session-level model the client may
// rotate mid-session.
//
// Realtime / Responses WS lets the client change the session model after
// the WS handshake via:
//
// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
//
// If we only capture the model from the very first frame, a client can ship
// gpt-4o on the first response.create (whitelisted as pass), then
// session.update to gpt-5.5, then send response.create without "model" so
// the per-frame resolver returns "" and the stale capturedSessionModel falls
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
if account == nil || len(payload) == 0 {
return ""
}
frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
if frameType != "session.update" {
return ""
}
original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
if original == "" {
return ""
}
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
......@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
......@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
len(firstClientMessage),
)
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
// frames are filtered via a wrapping FrameConn below so every client→
// upstream frame goes through the same policy evaluator/normalize/scope as
// HTTP entrypoints.
//
// We capture the session-level model from the first frame here so the
// per-frame filter (below) can fall back to it when a follow-up frame
// omits "model" — Realtime clients are allowed to send response.create
// without re-stating the model, in which case the upstream uses the model
// negotiated at session.update time. Without this fallback, an empty
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
// silently passed through, defeating the policy on every frame after
// the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
}
if blocked != nil {
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
// writeFrameMu, writes the entire frame, and Flushes the underlying
// bufio writer before returning (write.go:42 → write.go:307-311).
// The subsequent close handshake re-acquires the same writeFrameMu
// to send the close frame, so the error event is guaranteed to
// reach the kernel send buffer before any close frame is queued.
// No explicit flush hop is required here.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancelWrite()
}
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
firstClientMessage = updatedFirst
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
// 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
//
// 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
// 同一连接的不同 response.create 帧上发送不同 service_tier(参考
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
// goroutine)之间同步当前 turn 的 service_tier。
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
// 可直接 Store/Load 而无需额外封装。
var requestServiceTierPtr atomic.Pointer[string]
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
......@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
}
completedTurns := atomic.Int32{}
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
inner: &openAIWSClientFrameConn{conn: clientConn},
// 注意线程安全:filter 仅在 runClientToUpstream 这一条
// goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
// capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
// 加锁/原子化。
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
// session.update 修改 session-level model(Realtime /
// Responses WS 协议允许),如果不刷新就会出现
// "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
// → 不带 model 的 response.create fallback 到 gpt-4o" 的
// 绕过路径。这里只看 session.update 事件中的 session.model
// 字段,response.create 自己的 model 仍然由其本帧字段决定。
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
// Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame so the
// model whitelist still resolves. An empty model would miss
// any whitelist and silently fall back to pass.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
// 多轮 passthrough billing:仅在成功(non-block / non-err)
// 的 response.create 帧上更新 requestServiceTierPtr,使用
// filter 处理后的 payload,与首帧 policy-after-extract 语义
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
// - 非 response.create 帧(response.cancel /
// conversation.item.create / session.update 等)不携带
// per-response service_tier,不应覆盖前一轮值。
// - blocked != nil:该帧不会发送上游,billing tier 应保持
// 上一轮值。
// - policyErr != nil:异常路径,保持上一轮值。
// - 不带 service_tier 的 response.create 会让
// extractOpenAIServiceTierFromBody 返回 nil;这里有意
// 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
// service_tier 时按 default 处理,billing 应如实反映。
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
}
return out, blocked, policyErr
},
onBlock: func(blocked *OpenAIFastBlockedError) {
// See note above on Conn.Write being synchronous w.r.t. flush;
// no explicit flush is required to ensure the error event lands
// before the close frame.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes == nil {
return
}
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
},
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
......@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
ServiceTier: requestServiceTier,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
......@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: requestServiceTier,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
......
......@@ -3259,6 +3259,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
// GetOpenAIFastPolicySettings 获取 OpenAI fast 策略配置
func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultOpenAIFastPolicySettings(), nil
}
return nil, fmt.Errorf("get openai fast policy settings: %w", err)
}
if value == "" {
return DefaultOpenAIFastPolicySettings(), nil
}
var settings OpenAIFastPolicySettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
// JSON 损坏时静默 fallback 到默认配置会让策略意外失效(管理员配
// 置的 block/filter 规则被忽略)。记录 Warn 让运维能在出现异常
// 行为时定位到 settings 表里的脏数据。
slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
"error", err,
"key", SettingKeyOpenAIFastPolicySettings)
return DefaultOpenAIFastPolicySettings(), nil
}
return &settings, nil
}
// SetOpenAIFastPolicySettings 设置 OpenAI fast 策略配置
func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
validActions := map[string]bool{
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
}
validScopes := map[string]bool{
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
}
validTiers := map[string]bool{
OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
}
for i, rule := range settings.Rules {
tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
if tier == "" {
tier = OpenAIFastTierAny
}
if !validTiers[tier] {
return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
}
settings.Rules[i].ServiceTier = tier
if !validActions[rule.Action] {
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
}
if !validScopes[rule.Scope] {
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
}
for j, pattern := range rule.ModelWhitelist {
trimmed := strings.TrimSpace(pattern)
if trimmed == "" {
return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
}
settings.Rules[i].ModelWhitelist[j] = trimmed
}
if rule.FallbackAction != "" && !validActions[rule.FallbackAction] {
return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
}
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal openai fast policy settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
}
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {
......
......@@ -405,3 +405,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
},
}
}
// OpenAI Fast Policy 策略常量
// OpenAI 的 "fast 模式" 通过请求体中的 service_tier 字段识别:
// - "priority"(客户端可传 "fast",归一化为 "priority"):fast 模式
// - "flex":低优先级模式
// - 省略:normal 默认
//
// 本策略复用 BetaPolicyAction*/BetaPolicyScope* 常量语义,只是匹配键从
// anthropic-beta header 换成 body 的 service_tier 字段。
const (
OpenAIFastTierAny = "all" // 匹配任意已识别的 service_tier
OpenAIFastTierPriority = "priority" // 仅匹配 fast(priority)
OpenAIFastTierFlex = "flex" // 仅匹配 flex
)
// OpenAIFastPolicyRule 单条 OpenAI fast/flex 策略规则
type OpenAIFastPolicyRule struct {
ServiceTier string `json:"service_tier"` // "priority" | "flex" | "auto" | "default" | "scale" | "all"
Action string `json:"action"` // "pass" | "filter" | "block"
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
ModelWhitelist []string `json:"model_whitelist,omitempty"` // 模型匹配模式列表(为空=对所有模型生效)
FallbackAction string `json:"fallback_action,omitempty"` // 未匹配白名单的模型的处理方式
FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // 未匹配白名单时的自定义错误消息 (fallback_action=block 时生效)
}
// OpenAIFastPolicySettings OpenAI fast 策略配置
type OpenAIFastPolicySettings struct {
Rules []OpenAIFastPolicyRule `json:"rules"`
}
// DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。
// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段,
// 让上游按 normal 优先级处理。
//
// 为什么 ModelWhitelist 为空(=对所有模型生效):
// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使
// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁
// gpt-5.5*,"用 gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。
// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定
// 模型,可在 admin UI 中显式配置 model_whitelist。
func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
return &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{
{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{},
FallbackAction: BetaPolicyActionPass,
},
},
}
}
......@@ -484,6 +484,9 @@ export interface SystemSettings {
// Affiliate (邀请返利) feature switch
affiliate_enabled: boolean;
// OpenAI fast/flex policy
openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
export interface UpdateSettingsRequest {
......@@ -648,6 +651,9 @@ export interface UpdateSettingsRequest {
// Affiliate (邀请返利) feature switch
affiliate_enabled?: boolean;
// OpenAI fast/flex policy
openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
/**
......@@ -875,6 +881,29 @@ export async function updateRectifierSettings(
return data;
}
// ==================== OpenAI Fast Policy Settings ====================
/**
* OpenAI fast/flex policy rule interface.
* Matches backend dto.OpenAIFastPolicyRule.
*/
export interface OpenAIFastPolicyRule {
service_tier: "all" | "priority" | "flex";
action: "pass" | "filter" | "block";
scope: "all" | "oauth" | "apikey" | "bedrock";
error_message?: string;
model_whitelist?: string[];
fallback_action?: "pass" | "filter" | "block";
fallback_error_message?: string;
}
/**
* OpenAI fast/flex policy settings interface.
*/
export interface OpenAIFastPolicySettings {
rules: OpenAIFastPolicyRule[];
}
// ==================== Beta Policy Settings ====================
/**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment