Commit 1d0872e7 authored by yangjianbo's avatar yangjianbo
Browse files

feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode



新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层,
支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。

同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough,
并补充 i18n 文案与对应单测。

新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置,
用于宿主机场景下的反向代理与服务编排。
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 078fefed
# =============================================================================
# Sub2API Caddy Reverse Proxy Configuration (宿主机部署)
# =============================================================================
# 使用方法:
# 1. 安装 Caddy: https://caddyserver.com/docs/install
# 2. 修改下方 example.com 为你的域名
# 3. 确保域名 DNS 已指向服务器
# 4. 复制配置: sudo cp Caddyfile /etc/caddy/Caddyfile
# 5. 重载配置: sudo systemctl reload caddy
#
# Caddy 会自动申请和续期 Let's Encrypt SSL 证书
# =============================================================================
# 全局配置
{
# Let's Encrypt 邮箱通知
email mt21625457@gmail.com
# 服务器配置
servers {
# 启用 HTTP/2 和 HTTP/3
protocols h1 h2 h3
# 超时配置
timeouts {
read_body 30s
read_header 10s
# WebSocket/流式场景下,延长写入与空闲超时,避免长会话被过早回收
write 3600s
idle 3600s
}
}
}
# 修改为你的域名
dmit.leagsoft.ai {
# =========================================================================
# 静态资源长期缓存(高优先级,放在最前面)
# 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存
# =========================================================================
@static {
path /assets/*
path /logo.png
path /favicon.ico
}
header @static {
Cache-Control "public, max-age=31536000, immutable"
# 移除可能干扰缓存的头
-Pragma
-Expires
}
# =========================================================================
# TLS 安全配置
# =========================================================================
tls {
# 仅使用 TLS 1.2 和 1.3
protocols tls1.2 tls1.3
# 优先使用的加密套件
ciphers TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
}
# =========================================================================
# 反向代理配置
# =========================================================================
# OpenAI Responses(含 WebSocket/SSE)专用代理:
# 1) 禁用流式缓冲,降低中间层等待导致的断流概率
# 2) 上游强制 HTTP/1.1,保证 Upgrade 行为稳定
# 3) 放宽流生命周期,避免长会话被代理提前切断
@openai_responses {
path /openai/v1/responses*
}
reverse_proxy @openai_responses localhost:8080 {
flush_interval -1
stream_timeout 24h
stream_close_delay 5m
# 传递真实客户端信息
header_up X-Real-IP {remote_host}
header_up X-Forwarded-For {remote_host}
header_up X-Forwarded-Proto {scheme}
header_up X-Forwarded-Host {host}
header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP}
transport http {
versions 1.1
keepalive 120s
keepalive_idle_conns 256
read_buffer 32KB
write_buffer 32KB
compression off
}
}
reverse_proxy localhost:8080 {
# 健康检查
health_uri /health
health_interval 30s
health_timeout 10s
health_status 200
# 负载均衡策略(单节点可忽略,多节点时有用)
lb_policy round_robin
lb_try_duration 5s
lb_try_interval 250ms
# 传递真实客户端信息
# 兼容 Cloudflare 和直连:后端应优先读取 CF-Connecting-IP,其次 X-Real-IP
header_up X-Real-IP {remote_host}
header_up X-Forwarded-For {remote_host}
header_up X-Forwarded-Proto {scheme}
header_up X-Forwarded-Host {host}
# 保留 Cloudflare 原始头(如果存在)
# 后端获取 IP 的优先级建议: CF-Connecting-IP → X-Real-IP → X-Forwarded-For
header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP}
# 连接池优化
transport http {
keepalive 120s
keepalive_idle_conns 256
read_buffer 16KB
write_buffer 16KB
compression off
}
# 故障转移
fail_duration 30s
max_fails 3
unhealthy_status 500 502 503 504
}
# =========================================================================
# 压缩配置
# =========================================================================
encode {
zstd
gzip 6
minimum_length 256
match {
header Content-Type text/*
header Content-Type application/json*
header Content-Type application/javascript*
header Content-Type application/xml*
header Content-Type application/rss+xml*
header Content-Type image/svg+xml*
}
}
# =========================================================================
# 速率限制 (需要 caddy-ratelimit 插件)
# 如未安装插件,请注释掉此段
# =========================================================================
# rate_limit {
# zone api {
# key {remote_host}
# events 100
# window 1m
# }
# }
# =========================================================================
# 安全响应头
# =========================================================================
header {
# 防止点击劫持
X-Frame-Options "SAMEORIGIN"
# XSS 保护
X-XSS-Protection "1; mode=block"
# 防止 MIME 类型嗅探
X-Content-Type-Options "nosniff"
# 引用策略
Referrer-Policy "strict-origin-when-cross-origin"
# HSTS - 强制 HTTPS (max-age=1年)
Strict-Transport-Security "max-age=31536000; includeSubDomains; preload"
# 内容安全策略 (根据需要调整)
# Content-Security-Policy "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:;"
# 权限策略
Permissions-Policy "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()"
# 跨域资源策略
Cross-Origin-Opener-Policy "same-origin"
Cross-Origin-Embedder-Policy "require-corp"
Cross-Origin-Resource-Policy "same-origin"
# 移除敏感头
-Server
-X-Powered-By
}
# =========================================================================
# 请求大小限制 (防止大文件攻击)
# =========================================================================
request_body {
max_size 100MB
}
# =========================================================================
# 日志配置
# =========================================================================
log {
output file /var/log/caddy/sub2api.log {
roll_size 50mb
roll_keep 10
roll_keep_for 720h
}
format json
level INFO
}
# =========================================================================
# 错误处理
# =========================================================================
handle_errors {
respond "{err.status_code} {err.status_text}"
}
}
\ No newline at end of file
......@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
type GatewayOpenAIWSConfig struct {
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式(off/shared/dedicated
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true)
Enabled bool `mapstructure:"enabled"`
......@@ -1335,7 +1335,7 @@ func setDefaults() {
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
......@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
case "off", "shared", "dedicated":
case "off", "ctx_pool", "passthrough":
case "shared", "dedicated":
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
default:
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
......
......@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
}
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
}
}
......@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
name: "ingress_mode_default 必须为 off|shared|dedicated",
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},
......
......@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
}
const (
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
OpenAIWSIngressModeCtxPool = "ctx_pool"
OpenAIWSIngressModePassthrough = "passthrough"
)
func normalizeOpenAIWSIngressMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAIWSIngressModeOff:
return OpenAIWSIngressModeOff
case OpenAIWSIngressModeCtxPool:
return OpenAIWSIngressModeCtxPool
case OpenAIWSIngressModePassthrough:
return OpenAIWSIngressModePassthrough
case OpenAIWSIngressModeShared:
return OpenAIWSIngressModeShared
case OpenAIWSIngressModeDedicated:
......@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
return OpenAIWSIngressModeCtxPool
}
return normalized
}
return OpenAIWSIngressModeShared
return OpenAIWSIngressModeCtxPool
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
//
// 优先级:
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
// 4. defaultMode(非法时回退 shared
// 4. defaultMode(非法时回退 ctx_pool
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
if a == nil || !a.IsOpenAI() {
......@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return "", false
}
if enabled {
return OpenAIWSIngressModeShared, true
return OpenAIWSIngressModeCtxPool, true
}
return OpenAIWSIngressModeOff, true
}
......@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
return mode
}
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
return OpenAIWSIngressModeCtxPool
}
return resolvedDefault
}
......
......@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
t.Run("default fallback to shared", func(t *testing.T) {
t.Run("default fallback to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
t.Run("oauth mode field has highest priority", func(t *testing.T) {
......@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": false,
},
}
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("legacy enabled maps to shared", func(t *testing.T) {
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
......@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
shared := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
},
}
dedicated := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
......@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("non openai always off", func(t *testing.T) {
......
......@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
openaiSchedulerOnce sync.Once
openaiWSPool *openAIWSConnPool
openaiWSStateStore OpenAIWSStateStore
openaiScheduler OpenAIAccountScheduler
openaiAccountStats *openAIAccountRuntimeStats
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
openaiSchedulerOnce sync.Once
openaiWSPassthroughDialerOnce sync.Once
openaiWSPool *openAIWSConnPool
openaiWSStateStore OpenAIWSStateStore
openaiScheduler OpenAIAccountScheduler
openaiWSPassthroughDialer openAIWSClientDialer
openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics
......
......@@ -11,6 +11,7 @@ import (
"sync/atomic"
"time"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
......@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
conn *coderws.Conn
}
var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
......@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
}
}
func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
msgType, payload, err := c.conn.Read(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return msgType, payload, nil
}
func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
......
......@@ -46,9 +46,10 @@ const (
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
openAIWSPayloadSizeEstimateMaxItems = 16
openAIWSEventFlushBatchSizeDefault = 4
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
openAIWSPayloadLogSampleDefault = 0.2
openAIWSEventFlushBatchSizeDefault = 4
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
openAIWSPayloadLogSampleDefault = 0.2
openAIWSPassthroughIdleTimeoutDefault = time.Hour
openAIWSStoreDisabledConnModeStrict = "strict"
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
......@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
return s.openaiWSPool
}
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
if s == nil {
return nil
}
s.openaiWSPassthroughDialerOnce.Do(func() {
if s.openaiWSPassthroughDialer == nil {
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
}
})
return s.openaiWSPassthroughDialer
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
pool := s.getOpenAIWSConnPool()
if pool == nil {
......@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
return 15 * time.Minute
}
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
return timeout
}
return openAIWSPassthroughIdleTimeoutDefault
}
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
......@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeShared
ingressMode := OpenAIWSIngressModeCtxPool
if modeRouterV2Enabled {
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
if ingressMode == OpenAIWSIngressModeOff {
......@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
nil,
)
}
switch ingressMode {
case OpenAIWSIngressModePassthrough:
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
return s.proxyResponsesWebSocketV2Passthrough(
ctx,
c,
clientConn,
account,
token,
firstClientMessage,
hooks,
wsDecision,
)
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// continue
default:
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"websocket mode only supports ctx_pool/passthrough",
nil,
)
}
}
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
......
......@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
......@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
upstreamConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPassthroughDialer: captureDialer,
}
account := &Account{
ID: 452,
Name: "openai-ingress-passthrough",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
serverErrCh := make(chan error, 1)
resultCh := make(chan *OpenAIForwardResult, 1)
hooks := &OpenAIWSIngressHooks{
AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
if turnErr == nil && result != nil {
resultCh <- result
}
},
}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时")
}
select {
case result := <-resultCh:
require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
require.True(t, result.OpenAIWSMode)
require.Equal(t, 2, result.Usage.InputTokens)
require.Equal(t, 3, result.Usage.OutputTokens)
case <-time.After(2 * time.Second):
t.Fatal("未收到 passthrough turn 结果回调")
}
require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
gin.SetMode(gin.TestMode)
......
......@@ -15,6 +15,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
......@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
return event, nil
}
func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
payload, err := c.ReadMessage(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return coderws.MessageText, payload, nil
}
func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error {
return c.WriteJSON(ctx, json.RawMessage(payload))
}
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
_ = ctx
return nil
......
......@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
switch mode {
case OpenAIWSIngressModeOff:
return openAIWSHTTPDecision("account_mode_off")
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
// continue
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// 历史值兼容:按 ctx_pool 处理。
mode = OpenAIWSIngressModeCtxPool
default:
return openAIWSHTTPDecision("account_mode_off")
}
......
......@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("off mode routes to http", func(t *testing.T) {
......@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
require.Equal(t, "account_mode_off", decision.Reason)
})
t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
legacyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
......@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_shared", decision.Reason)
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
passthroughAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
})
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
......@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
......
package openai_ws_v2
import (
"context"
)
// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想:
// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。
//
// Reference:
// - Project: caddyserver/caddy (Apache-2.0)
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
// - Files:
// - modules/caddyhttp/reverseproxy/streaming.go
// - modules/caddyhttp/reverseproxy/reverseproxy.go
func runCaddyStyleRelay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
}
package openai_ws_v2
import "context"
// EntryInput 是 passthrough v2 数据面的入口参数。
type EntryInput struct {
Ctx context.Context
ClientConn FrameConn
UpstreamConn FrameConn
FirstClientMessage []byte
Options RelayOptions
}
// RunEntry 是 openai_ws_v2 包对外的统一入口。
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
return runCaddyStyleRelay(
input.Ctx,
input.ClientConn,
input.UpstreamConn,
input.FirstClientMessage,
input.Options,
)
}
package openai_ws_v2
import (
"sync/atomic"
)
// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。
type MetricsSnapshot struct {
SemanticMutationTotal int64 `json:"semantic_mutation_total"`
UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
}
var (
// passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。
passthroughSemanticMutationTotal atomic.Int64
passthroughUsageParseFailureTotal atomic.Int64
)
func recordUsageParseFailure() {
passthroughUsageParseFailureTotal.Add(1)
}
// SnapshotMetrics 返回当前 passthrough 指标快照。
func SnapshotMetrics() MetricsSnapshot {
return MetricsSnapshot{
SemanticMutationTotal: passthroughSemanticMutationTotal.Load(),
UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
}
}
This diff is collapsed.
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestRunEntry_DelegatesRelay(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
result, relayExit := RunEntry(EntryInput{
Ctx: context.Background(),
ClientConn: clientConn,
UpstreamConn: upstreamConn,
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
})
require.Nil(t, relayExit)
require.Equal(t, "resp_entry", result.RequestID)
}
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
t.Parallel()
t.Run("read client eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write upstream failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_upstream", sig.stage)
require.False(t, sig.graceful)
})
t.Run("forwarded counter and trace callback", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
forwarded := &atomic.Int64{}
traces := make([]RelayTraceEvent, 0, 2)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
forwarded,
func(event RelayTraceEvent) {
traces = append(traces, event)
},
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.Equal(t, int64(1), forwarded.Load())
require.NotEmpty(t, traces)
})
}
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
t.Parallel()
t.Run("read upstream eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_upstream", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write client failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_client", sig.stage)
})
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(true)
dropped := &atomic.Int64{}
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
dropped,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "drain_terminal", sig.stage)
require.True(t, sig.graceful)
require.Equal(t, int64(1), dropped.Load())
})
}
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
lastActivity := &atomic.Int64{}
lastActivity.Store(time.Now().UnixNano())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
select {
case <-exitCh:
t.Fatal("unexpected idle timeout signal")
case <-time.After(200 * time.Millisecond):
}
}
func TestHelperFunctionsCoverage(t *testing.T) {
t.Parallel()
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
require.Equal(t, "", relayErrorString(nil))
require.Equal(t, "x", relayErrorString(errors.New("x")))
require.True(t, isDisconnectError(io.EOF))
require.True(t, isDisconnectError(net.ErrClosed))
require.True(t, isDisconnectError(context.Canceled))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
require.True(t, isDisconnectError(errors.New("broken pipe")))
require.False(t, isDisconnectError(errors.New("unrelated")))
require.True(t, isTokenEvent("response.output_text.delta"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.completed"))
require.False(t, isTokenEvent(""))
require.False(t, isTokenEvent("response.created"))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
ch := make(chan relayExitSignal, 1)
ch <- relayExitSignal{stage: "ok"}
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
require.True(t, ok)
require.Equal(t, "ok", sig.stage)
ch <- relayExitSignal{stage: "ok2"}
sig, ok = waitRelayExit(ch, 0)
require.True(t, ok)
require.Equal(t, "ok2", sig.stage)
_, ok = waitRelayExit(ch, 10*time.Millisecond)
require.False(t, ok)
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
require.True(t, ok)
require.Equal(t, 3, n)
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
require.False(t, ok)
n, ok = parseUsageIntField(gjson.Result{}, false)
require.True(t, ok)
require.Equal(t, 0, n)
_, ok = parseUsageIntField(gjson.Result{}, true)
require.False(t, ok)
}
func TestParseUsageAndEnrichCoverage(t *testing.T) {
t.Parallel()
state := &relayState{}
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
require.Equal(t, 0, state.usage.InputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
require.Equal(t, 2, state.usage.InputTokens)
require.Equal(t, 1, state.usage.OutputTokens)
require.Equal(t, 1, state.usage.CacheReadInputTokens)
result := &RelayResult{}
enrichResult(result, state, 5*time.Millisecond)
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
require.Equal(t, 5*time.Millisecond, result.Duration)
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
require.Equal(t, 2, state.usage.InputTokens)
enrichResult(nil, state, 0)
}
func TestEmitTurnCompleteCoverage(t *testing.T) {
t.Parallel()
// 非 terminal 事件不应触发。
called := 0
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: false,
eventType: "response.output_text.delta",
responseID: "resp_ignored",
usage: Usage{InputTokens: 1},
})
require.Equal(t, 0, called)
// 缺少 response_id 时不应触发。
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
})
require.Equal(t, 0, called)
// terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。
var got RelayTurnResult
emitTurnComplete(func(turn RelayTurnResult) {
called++
got = turn
}, nil, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
responseID: "resp_emit",
usage: Usage{InputTokens: 2, OutputTokens: 3},
})
require.Equal(t, 1, called)
require.Equal(t, "resp_emit", got.RequestID)
require.Equal(t, "response.completed", got.TerminalEventType)
require.Equal(t, 2, got.Usage.InputTokens)
require.Equal(t, 3, got.Usage.OutputTokens)
require.Equal(t, "", got.RequestModel)
}
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
t.Parallel()
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
require.False(t, isDisconnectError(errors.New(" ")))
}
func TestIsTokenEventCoverageBranches(t *testing.T) {
t.Parallel()
require.False(t, isTokenEvent("response.in_progress"))
require.False(t, isTokenEvent("response.output_item.added"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.output"))
require.True(t, isTokenEvent("response.done"))
}
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
t.Parallel()
now := time.Unix(100, 0)
// nil state
require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
require.False(t, ok)
state := &relayState{}
timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
require.NotNil(t, timing)
require.Equal(t, now, timing.startAt)
// 再次获取返回同一条 timing
timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
require.NotNil(t, timing2)
require.Equal(t, now, timing2.startAt)
// 删除存在键
deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.True(t, ok)
require.Equal(t, now, deleted.startAt)
// 删除不存在键
_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.False(t, ok)
}
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
t.Parallel()
state := &relayState{requestModel: "gpt-5"}
startAt := time.Unix(0, 0)
now := startAt
nowFn := func() time.Time {
now = now.Add(5 * time.Millisecond)
return now
}
// 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。
observed := observeUpstreamMessage(
state,
[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
startAt,
nowFn,
nil,
)
require.False(t, observed.terminal)
require.Equal(t, "", observed.responseID)
// terminal:允许兜底用顶层 id(用于兼容少数字段变体)。
observed = observeUpstreamMessage(
state,
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
startAt,
nowFn,
nil,
)
require.True(t, observed.terminal)
require.Equal(t, "resp_fallback", observed.responseID)
}
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Read(ctx)
}
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *openAIWSClientFrameConn) Close() error {
if c == nil || c.conn == nil {
return nil
}
_ = c.conn.Close(coderws.StatusNormalClosure, "")
_ = c.conn.CloseNow()
return nil
}
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
ctx context.Context,
c *gin.Context,
clientConn *coderws.Conn,
account *Account,
token string,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
wsDecision OpenAIWSProtocolDecision,
) error {
if s == nil {
return errors.New("service is nil")
}
if clientConn == nil {
return errors.New("client websocket is nil")
}
if account == nil {
return errors.New("account is nil")
}
if strings.TrimSpace(token) == "" {
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
account.ID,
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
openaiwsv2RelayMessageTypeName(coderws.MessageText),
len(firstClientMessage),
)
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
}
wsHost := "-"
wsPath := "-"
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
logOpenAIWSV2Passthrough(
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
account.ID,
wsHost,
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
isCodexCLI := false
if c != nil {
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
}
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
isCodexCLI = true
}
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
dialer := s.getOpenAIWSPassthroughDialer()
if dialer == nil {
return errors.New("openai ws passthrough dialer is nil")
}
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
defer cancelDial()
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
if err != nil {
logOpenAIWSV2Passthrough(
"relay_dial_failed account_id=%d status_code=%d err=%s",
account.ID,
statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
)
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
}
defer func() {
_ = upstreamConn.Close()
}()
logOpenAIWSV2Passthrough(
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
account.ID,
statusCode,
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
)
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
if !ok {
return errors.New("openai ws passthrough upstream connection does not support frame relay")
}
completedTurns := atomic.Int32{}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText,
OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s",
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
)
},
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
turnNo := int(completedTurns.Add(1))
turnResult := &OpenAIForwardResult{
RequestID: turn.RequestID,
Usage: OpenAIUsage{
InputTokens: turn.Usage.InputTokens,
OutputTokens: turn.Usage.OutputTokens,
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
Stream: true,
OpenAIWSMode: true,
Duration: turn.Duration,
FirstTokenMs: turn.FirstTokenMs,
}
logOpenAIWSV2Passthrough(
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
account.ID,
turnNo,
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
turnResult.Duration.Milliseconds(),
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
turnResult.Usage.InputTokens,
turnResult.Usage.OutputTokens,
turnResult.Usage.CacheReadInputTokens,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnNo, turnResult, nil)
}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
account.ID,
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
event.PayloadBytes,
event.Graceful,
event.WroteDownstream,
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
)
},
},
})
result := &OpenAIForwardResult{
RequestID: relayResult.RequestID,
Usage: OpenAIUsage{
InputTokens: relayResult.Usage.InputTokens,
OutputTokens: relayResult.Usage.OutputTokens,
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
Stream: true,
OpenAIWSMode: true,
Duration: relayResult.Duration,
FirstTokenMs: relayResult.FirstTokenMs,
}
turnCount := int(completedTurns.Load())
if relayExit == nil {
logOpenAIWSV2Passthrough(
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(1, result, nil)
}
return nil
}
logOpenAIWSV2Passthrough(
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
relayExit.WroteDownstream,
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
relayErr := relayExit.Err
if relayExit.Stage == "idle_timeout" {
relayErr = NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"client websocket idle timeout",
relayErr,
)
}
turnErr := wrapOpenAIWSIngressTurnError(
relayExit.Stage,
relayErr,
relayExit.WroteDownstream,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnCount+1, nil, turnErr)
}
return turnErr
}
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
err error,
statusCode int,
handshakeHeaders http.Header,
) error {
if err == nil {
return nil
}
wrappedErr := err
var dialErr *openAIWSDialError
if !errors.As(err, &dialErr) {
wrappedErr = &openAIWSDialError{
StatusCode: statusCode,
ResponseHeaders: cloneHeader(handshakeHeaders),
Err: err,
}
}
if errors.Is(err, context.Canceled) {
return err
}
if errors.Is(err, context.DeadlineExceeded) {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket connect timeout",
wrappedErr,
)
}
if statusCode == http.StatusTooManyRequests {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket is busy, please retry later",
wrappedErr,
)
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket authentication failed",
wrappedErr,
)
}
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket handshake rejected",
wrappedErr,
)
}
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
}
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
switch msgType {
case coderws.MessageText:
return "text"
case coderws.MessageBinary:
return "binary"
default:
return fmt.Sprintf("unknown(%d)", msgType)
}
}
func relayErrorText(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
if firstTokenMs == nil {
return -1
}
return *firstTokenMs
}
func logOpenAIWSV2Passthrough(format string, args ...any) {
logger.LegacyPrintf(
"service.openai_ws_v2",
"[OpenAI WS v2 passthrough] %s "+format,
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
)
}
......@@ -209,8 +209,9 @@ gateway:
openai_ws:
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
mode_router_v2_enabled: false
# ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
ingress_mode_default: shared
# ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
# 兼容旧值:shared/dedicated 会按 ctx_pool 处理。
ingress_mode_default: ctx_pool
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
enabled: true
# 按账号类型细分开关
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment