Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7076717b
Unverified
Commit
7076717b
authored
Mar 05, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 05, 2026
Browse files
Merge pull request #772 from mt21625457/aicodex2api-main
feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
parents
33988637
c0a4fcea
Changes
25
Expand all
Show whitespace changes
Inline
Side-by-side
backend/internal/config/config.go
View file @
7076717b
...
...
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
type
GatewayOpenAIWSConfig
struct
{
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
ModeRouterV2Enabled
bool
`mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式(off/
shared/dedicated
)
// IngressModeDefault: ingress 默认模式(off/
ctx_pool/passthrough
)
IngressModeDefault
string
`mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true)
Enabled
bool
`mapstructure:"enabled"`
...
...
@@ -1335,7 +1335,7 @@ func setDefaults() {
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
viper
.
SetDefault
(
"gateway.openai_ws.enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.mode_router_v2_enabled"
,
false
)
viper
.
SetDefault
(
"gateway.openai_ws.ingress_mode_default"
,
"
shared
"
)
viper
.
SetDefault
(
"gateway.openai_ws.ingress_mode_default"
,
"
ctx_pool
"
)
viper
.
SetDefault
(
"gateway.openai_ws.oauth_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.apikey_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.force_http"
,
false
)
...
...
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
}
if
mode
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
Gateway
.
OpenAIWS
.
IngressModeDefault
));
mode
!=
""
{
switch
mode
{
case
"off"
,
"shared"
,
"dedicated"
:
case
"off"
,
"ctx_pool"
,
"passthrough"
:
case
"shared"
,
"dedicated"
:
slog
.
Warn
(
"gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough"
,
"value"
,
mode
)
default
:
return
fmt
.
Errorf
(
"gateway.openai_ws.ingress_mode_default must be one of off|
shared|dedicated
"
)
return
fmt
.
Errorf
(
"gateway.openai_ws.ingress_mode_default must be one of off|
ctx_pool|passthrough
"
)
}
}
if
mode
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
Gateway
.
OpenAIWS
.
StoreDisabledConnMode
));
mode
!=
""
{
...
...
backend/internal/config/config_test.go
View file @
7076717b
...
...
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
{
t
.
Fatalf
(
"Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false"
)
}
if
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
!=
"
shared
"
{
t
.
Fatalf
(
"Gateway.OpenAIWS.IngressModeDefault = %q, want %q"
,
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
,
"
shared
"
)
if
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
!=
"
ctx_pool
"
{
t
.
Fatalf
(
"Gateway.OpenAIWS.IngressModeDefault = %q, want %q"
,
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
,
"
ctx_pool
"
)
}
}
...
...
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
wantErr
:
"gateway.openai_ws.store_disabled_conn_mode"
,
},
{
name
:
"ingress_mode_default 必须为 off|
shared|dedicated
"
,
name
:
"ingress_mode_default 必须为 off|
ctx_pool|passthrough
"
,
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
"invalid"
},
wantErr
:
"gateway.openai_ws.ingress_mode_default"
,
},
...
...
backend/internal/service/account.go
View file @
7076717b
...
...
@@ -856,12 +856,18 @@ const (
OpenAIWSIngressModeOff
=
"off"
OpenAIWSIngressModeShared
=
"shared"
OpenAIWSIngressModeDedicated
=
"dedicated"
OpenAIWSIngressModeCtxPool
=
"ctx_pool"
OpenAIWSIngressModePassthrough
=
"passthrough"
)
func
normalizeOpenAIWSIngressMode
(
mode
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
mode
))
{
case
OpenAIWSIngressModeOff
:
return
OpenAIWSIngressModeOff
case
OpenAIWSIngressModeCtxPool
:
return
OpenAIWSIngressModeCtxPool
case
OpenAIWSIngressModePassthrough
:
return
OpenAIWSIngressModePassthrough
case
OpenAIWSIngressModeShared
:
return
OpenAIWSIngressModeShared
case
OpenAIWSIngressModeDedicated
:
...
...
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
func
normalizeOpenAIWSIngressDefaultMode
(
mode
string
)
string
{
if
normalized
:=
normalizeOpenAIWSIngressMode
(
mode
);
normalized
!=
""
{
if
normalized
==
OpenAIWSIngressModeShared
||
normalized
==
OpenAIWSIngressModeDedicated
{
return
OpenAIWSIngressModeCtxPool
}
return
normalized
}
return
OpenAIWSIngressMode
Shared
return
OpenAIWSIngressMode
CtxPool
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/
shared/dedicated
)。
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/
ctx_pool/passthrough
)。
//
// 优先级:
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
// 4. defaultMode(非法时回退
shared
)
// 4. defaultMode(非法时回退
ctx_pool
)
func
(
a
*
Account
)
ResolveOpenAIResponsesWebSocketV2Mode
(
defaultMode
string
)
string
{
resolvedDefault
:=
normalizeOpenAIWSIngressDefaultMode
(
defaultMode
)
if
a
==
nil
||
!
a
.
IsOpenAI
()
{
...
...
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return
""
,
false
}
if
enabled
{
return
OpenAIWSIngressMode
Shared
,
true
return
OpenAIWSIngressMode
CtxPool
,
true
}
return
OpenAIWSIngressModeOff
,
true
}
...
...
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
if
mode
,
ok
:=
resolveBoolMode
(
"openai_ws_enabled"
);
ok
{
return
mode
}
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
if
resolvedDefault
==
OpenAIWSIngressModeShared
||
resolvedDefault
==
OpenAIWSIngressModeDedicated
{
return
OpenAIWSIngressModeCtxPool
}
return
resolvedDefault
}
...
...
backend/internal/service/account_openai_passthrough_test.go
View file @
7076717b
...
...
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
}
func
TestAccount_ResolveOpenAIResponsesWebSocketV2Mode
(
t
*
testing
.
T
)
{
t
.
Run
(
"default fallback to
shared
"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"default fallback to
ctx_pool
"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
}
require
.
Equal
(
t
,
OpenAIWSIngressMode
Shared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
""
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
Shared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
"invalid"
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
CtxPool
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
""
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
CtxPool
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
"invalid"
))
})
t
.
Run
(
"oauth mode field has highest priority"
,
func
(
t
*
testing
.
T
)
{
...
...
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Dedicated
,
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Passthrough
,
"openai_oauth_responses_websockets_v2_enabled"
:
false
,
"responses_websockets_v2_enabled"
:
false
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressMode
Dedicated
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
Shared
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
Passthrough
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
CtxPool
))
})
t
.
Run
(
"legacy enabled maps to
shared
"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"legacy enabled maps to
ctx_pool
"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
...
...
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled"
:
true
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeShared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
require
.
Equal
(
t
,
OpenAIWSIngressModeCtxPool
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
})
t
.
Run
(
"shared/dedicated mode strings are compatible with ctx_pool"
,
func
(
t
*
testing
.
T
)
{
shared
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeShared
,
},
}
dedicated
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeDedicated
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeShared
,
shared
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
require
.
Equal
(
t
,
OpenAIWSIngressModeDedicated
,
dedicated
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
require
.
Equal
(
t
,
OpenAIWSIngressModeCtxPool
,
normalizeOpenAIWSIngressDefaultMode
(
OpenAIWSIngressModeShared
))
require
.
Equal
(
t
,
OpenAIWSIngressModeCtxPool
,
normalizeOpenAIWSIngressDefaultMode
(
OpenAIWSIngressModeDedicated
))
})
t
.
Run
(
"legacy disabled maps to off"
,
func
(
t
*
testing
.
T
)
{
...
...
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled"
:
true
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeOff
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
Shared
))
require
.
Equal
(
t
,
OpenAIWSIngressModeOff
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
CtxPool
))
})
t
.
Run
(
"non openai always off"
,
func
(
t
*
testing
.
T
)
{
...
...
backend/internal/service/openai_gateway_service.go
View file @
7076717b
...
...
@@ -266,9 +266,11 @@ type OpenAIGatewayService struct {
openaiWSPoolOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
openaiSchedulerOnce
sync
.
Once
openaiWSPassthroughDialerOnce
sync
.
Once
openaiWSPool
*
openAIWSConnPool
openaiWSStateStore
OpenAIWSStateStore
openaiScheduler
OpenAIAccountScheduler
openaiWSPassthroughDialer
openAIWSClientDialer
openaiAccountStats
*
openAIAccountRuntimeStats
openaiWSFallbackUntil
sync
.
Map
// key: int64(accountID), value: time.Time
...
...
backend/internal/service/openai_ws_client.go
View file @
7076717b
...
...
@@ -11,6 +11,7 @@ import (
"sync/atomic"
"time"
openaiwsv2
"github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
...
...
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
conn
*
coderws
.
Conn
}
var
_
openaiwsv2
.
FrameConn
=
(
*
coderOpenAIWSClientConn
)(
nil
)
func
(
c
*
coderOpenAIWSClientConn
)
WriteJSON
(
ctx
context
.
Context
,
value
any
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
...
...
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
}
}
func
(
c
*
coderOpenAIWSClientConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
coderws
.
MessageText
,
nil
,
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
msgType
,
payload
,
err
:=
c
.
conn
.
Read
(
ctx
)
if
err
!=
nil
{
return
coderws
.
MessageText
,
nil
,
err
}
return
msgType
,
payload
,
nil
}
func
(
c
*
coderOpenAIWSClientConn
)
WriteFrame
(
ctx
context
.
Context
,
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
return
c
.
conn
.
Write
(
ctx
,
msgType
,
payload
)
}
func
(
c
*
coderOpenAIWSClientConn
)
Ping
(
ctx
context
.
Context
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
...
...
backend/internal/service/openai_ws_forwarder.go
View file @
7076717b
...
...
@@ -49,6 +49,7 @@ const (
openAIWSEventFlushBatchSizeDefault
=
4
openAIWSEventFlushIntervalDefault
=
25
*
time
.
Millisecond
openAIWSPayloadLogSampleDefault
=
0.2
openAIWSPassthroughIdleTimeoutDefault
=
time
.
Hour
openAIWSStoreDisabledConnModeStrict
=
"strict"
openAIWSStoreDisabledConnModeAdaptive
=
"adaptive"
...
...
@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
return
s
.
openaiWSPool
}
func
(
s
*
OpenAIGatewayService
)
getOpenAIWSPassthroughDialer
()
openAIWSClientDialer
{
if
s
==
nil
{
return
nil
}
s
.
openaiWSPassthroughDialerOnce
.
Do
(
func
()
{
if
s
.
openaiWSPassthroughDialer
==
nil
{
s
.
openaiWSPassthroughDialer
=
newDefaultOpenAIWSClientDialer
()
}
})
return
s
.
openaiWSPassthroughDialer
}
func
(
s
*
OpenAIGatewayService
)
SnapshotOpenAIWSPoolMetrics
()
OpenAIWSPoolMetricsSnapshot
{
pool
:=
s
.
getOpenAIWSConnPool
()
if
pool
==
nil
{
...
...
@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
return
15
*
time
.
Minute
}
func
(
s
*
OpenAIGatewayService
)
openAIWSPassthroughIdleTimeout
()
time
.
Duration
{
if
timeout
:=
s
.
openAIWSReadTimeout
();
timeout
>
0
{
return
timeout
}
return
openAIWSPassthroughIdleTimeoutDefault
}
func
(
s
*
OpenAIGatewayService
)
openAIWSWriteTimeout
()
time
.
Duration
{
if
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
>
0
{
return
time
.
Duration
(
s
.
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
)
*
time
.
Second
...
...
@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsDecision
:=
s
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
account
)
modeRouterV2Enabled
:=
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
ingressMode
:=
OpenAIWSIngressMode
Shared
ingressMode
:=
OpenAIWSIngressMode
CtxPool
if
modeRouterV2Enabled
{
ingressMode
=
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
s
.
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
)
if
ingressMode
==
OpenAIWSIngressModeOff
{
...
...
@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
nil
,
)
}
switch
ingressMode
{
case
OpenAIWSIngressModePassthrough
:
if
wsDecision
.
Transport
!=
OpenAIUpstreamTransportResponsesWebsocketV2
{
return
fmt
.
Errorf
(
"websocket ingress requires ws_v2 transport, got=%s"
,
wsDecision
.
Transport
)
}
return
s
.
proxyResponsesWebSocketV2Passthrough
(
ctx
,
c
,
clientConn
,
account
,
token
,
firstClientMessage
,
hooks
,
wsDecision
,
)
case
OpenAIWSIngressModeCtxPool
,
OpenAIWSIngressModeShared
,
OpenAIWSIngressModeDedicated
:
// continue
default
:
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"websocket mode only supports ctx_pool/passthrough"
,
nil
,
)
}
}
if
wsDecision
.
Transport
!=
OpenAIUpstreamTransportResponsesWebsocketV2
{
return
fmt
.
Errorf
(
"websocket ingress requires ws_v2 transport, got=%s"
,
wsDecision
.
Transport
)
...
...
backend/internal/service/openai_ws_forwarder_ingress_session_test.go
View file @
7076717b
...
...
@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
require
.
True
(
t
,
<-
turnWSModeCh
,
"首轮 turn 应标记为 WS 模式"
)
require
.
True
(
t
,
<-
turnWSModeCh
,
"第二轮 turn 应标记为 WS 模式"
)
require
.
NoError
(
t
,
clientConn
.
Close
(
coderws
.
StatusNormalClosure
,
"done"
)
)
_
=
clientConn
.
Close
(
coderws
.
StatusNormalClosure
,
"done"
)
select
{
case
serverErr
:=
<-
serverErrCh
:
...
...
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
require
.
Equal
(
t
,
2
,
dialer
.
DialCount
(),
"dedicated 模式下跨客户端会话不应复用上游连接"
)
}
func
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
OpenAIWSIngressModeCtxPool
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
3
cfg
.
Gateway
.
OpenAIWS
.
ReadTimeoutSeconds
=
3
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
=
3
upstreamConn
:=
&
openAIWSCaptureConn
{
events
:
[][]
byte
{
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`
),
},
}
captureDialer
:=
&
openAIWSCaptureDialer
{
conn
:
upstreamConn
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSPassthroughDialer
:
captureDialer
,
}
account
:=
&
Account
{
ID
:
452
,
Name
:
"openai-ingress-passthrough"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
},
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_mode"
:
OpenAIWSIngressModePassthrough
,
},
}
serverErrCh
:=
make
(
chan
error
,
1
)
resultCh
:=
make
(
chan
*
OpenAIForwardResult
,
1
)
hooks
:=
&
OpenAIWSIngressHooks
{
AfterTurn
:
func
(
_
int
,
result
*
OpenAIForwardResult
,
turnErr
error
)
{
if
turnErr
==
nil
&&
result
!=
nil
{
resultCh
<-
result
}
},
}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
conn
,
err
:=
coderws
.
Accept
(
w
,
r
,
&
coderws
.
AcceptOptions
{
CompressionMode
:
coderws
.
CompressionContextTakeover
,
})
if
err
!=
nil
{
serverErrCh
<-
err
return
}
defer
func
()
{
_
=
conn
.
CloseNow
()
}()
rec
:=
httptest
.
NewRecorder
()
ginCtx
,
_
:=
gin
.
CreateTestContext
(
rec
)
req
:=
r
.
Clone
(
r
.
Context
())
req
.
Header
=
req
.
Header
.
Clone
()
req
.
Header
.
Set
(
"User-Agent"
,
"unit-test-agent/1.0"
)
ginCtx
.
Request
=
req
readCtx
,
cancel
:=
context
.
WithTimeout
(
r
.
Context
(),
3
*
time
.
Second
)
msgType
,
firstMessage
,
readErr
:=
conn
.
Read
(
readCtx
)
cancel
()
if
readErr
!=
nil
{
serverErrCh
<-
readErr
return
}
if
msgType
!=
coderws
.
MessageText
&&
msgType
!=
coderws
.
MessageBinary
{
serverErrCh
<-
errors
.
New
(
"unsupported websocket client message type"
)
return
}
serverErrCh
<-
svc
.
ProxyResponsesWebSocketFromClient
(
r
.
Context
(),
ginCtx
,
conn
,
account
,
"sk-test"
,
firstMessage
,
hooks
)
}))
defer
wsServer
.
Close
()
dialCtx
,
cancelDial
:=
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
clientConn
,
_
,
err
:=
coderws
.
Dial
(
dialCtx
,
"ws"
+
strings
.
TrimPrefix
(
wsServer
.
URL
,
"http"
),
nil
)
cancelDial
()
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
clientConn
.
CloseNow
()
}()
writeCtx
,
cancelWrite
:=
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
err
=
clientConn
.
Write
(
writeCtx
,
coderws
.
MessageText
,
[]
byte
(
`{"type":"response.create","model":"gpt-5.1","stream":false}`
))
cancelWrite
()
require
.
NoError
(
t
,
err
)
readCtx
,
cancelRead
:=
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
_
,
event
,
readErr
:=
clientConn
.
Read
(
readCtx
)
cancelRead
()
require
.
NoError
(
t
,
readErr
)
require
.
Equal
(
t
,
"response.completed"
,
gjson
.
GetBytes
(
event
,
"type"
)
.
String
())
require
.
Equal
(
t
,
"resp_passthrough_turn_1"
,
gjson
.
GetBytes
(
event
,
"response.id"
)
.
String
())
_
=
clientConn
.
Close
(
coderws
.
StatusNormalClosure
,
"done"
)
select
{
case
serverErr
:=
<-
serverErrCh
:
require
.
NoError
(
t
,
serverErr
)
case
<-
time
.
After
(
5
*
time
.
Second
)
:
t
.
Fatal
(
"等待 passthrough websocket 结束超时"
)
}
select
{
case
result
:=
<-
resultCh
:
require
.
Equal
(
t
,
"resp_passthrough_turn_1"
,
result
.
RequestID
)
require
.
True
(
t
,
result
.
OpenAIWSMode
)
require
.
Equal
(
t
,
2
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
3
,
result
.
Usage
.
OutputTokens
)
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"未收到 passthrough turn 结果回调"
)
}
require
.
Equal
(
t
,
1
,
captureDialer
.
DialCount
(),
"passthrough 模式应直接建立上游 websocket"
)
require
.
Len
(
t
,
upstreamConn
.
writes
,
1
,
"passthrough 模式应透传首条 response.create"
)
}
func
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/service/openai_ws_forwarder_success_test.go
View file @
7076717b
...
...
@@ -15,6 +15,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
...
...
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
return
event
,
nil
}
func
(
c
*
openAIWSCaptureConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
payload
,
err
:=
c
.
ReadMessage
(
ctx
)
if
err
!=
nil
{
return
coderws
.
MessageText
,
nil
,
err
}
return
coderws
.
MessageText
,
payload
,
nil
}
func
(
c
*
openAIWSCaptureConn
)
WriteFrame
(
ctx
context
.
Context
,
_
coderws
.
MessageType
,
payload
[]
byte
)
error
{
return
c
.
WriteJSON
(
ctx
,
json
.
RawMessage
(
payload
))
}
func
(
c
*
openAIWSCaptureConn
)
Ping
(
ctx
context
.
Context
)
error
{
_
=
ctx
return
nil
...
...
backend/internal/service/openai_ws_protocol_resolver.go
View file @
7076717b
...
...
@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
switch
mode
{
case
OpenAIWSIngressModeOff
:
return
openAIWSHTTPDecision
(
"account_mode_off"
)
case
OpenAIWSIngressMode
Shared
,
OpenAIWSIngressMode
Dedicated
:
case
OpenAIWSIngressMode
CtxPool
,
OpenAIWSIngressMode
Passthrough
:
// continue
case
OpenAIWSIngressModeShared
,
OpenAIWSIngressModeDedicated
:
// 历史值兼容:按 ctx_pool 处理。
mode
=
OpenAIWSIngressModeCtxPool
default
:
return
openAIWSHTTPDecision
(
"account_mode_off"
)
}
...
...
backend/internal/service/openai_ws_protocol_resolver_test.go
View file @
7076717b
...
...
@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
OpenAIWSIngressMode
Shared
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
OpenAIWSIngressMode
CtxPool
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Dedicated
,
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
CtxPool
,
},
}
t
.
Run
(
"
dedicated
mode routes to ws v2"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"
ctx_pool
mode routes to ws v2"
,
func
(
t
*
testing
.
T
)
{
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_
dedicated
"
,
decision
.
Reason
)
require
.
Equal
(
t
,
"ws_v2_mode_
ctx_pool
"
,
decision
.
Reason
)
})
t
.
Run
(
"off mode routes to http"
,
func
(
t
*
testing
.
T
)
{
...
...
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
require
.
Equal
(
t
,
"account_mode_off"
,
decision
.
Reason
)
})
t
.
Run
(
"legacy boolean maps to
shared
in v2 router"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"legacy boolean maps to
ctx_pool
in v2 router"
,
func
(
t
*
testing
.
T
)
{
legacyAccount
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
...
...
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
legacyAccount
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_shared"
,
decision
.
Reason
)
require
.
Equal
(
t
,
"ws_v2_mode_ctx_pool"
,
decision
.
Reason
)
})
t
.
Run
(
"passthrough mode routes to ws v2"
,
func
(
t
*
testing
.
T
)
{
passthroughAccount
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModePassthrough
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
passthroughAccount
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_passthrough"
,
decision
.
Reason
)
})
t
.
Run
(
"non-positive concurrency is rejected in v2 router"
,
func
(
t
*
testing
.
T
)
{
...
...
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Shared
,
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
CtxPool
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
invalidConcurrency
)
...
...
backend/internal/service/openai_ws_v2/caddy_adapter.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"context"
)
// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想:
// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。
//
// Reference:
// - Project: caddyserver/caddy (Apache-2.0)
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
// - Files:
// - modules/caddyhttp/reverseproxy/streaming.go
// - modules/caddyhttp/reverseproxy/reverseproxy.go
func
runCaddyStyleRelay
(
ctx
context
.
Context
,
clientConn
FrameConn
,
upstreamConn
FrameConn
,
firstClientMessage
[]
byte
,
options
RelayOptions
,
)
(
RelayResult
,
*
RelayExit
)
{
return
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstClientMessage
,
options
)
}
backend/internal/service/openai_ws_v2/entry.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
"context"
// EntryInput 是 passthrough v2 数据面的入口参数。
type
EntryInput
struct
{
Ctx
context
.
Context
ClientConn
FrameConn
UpstreamConn
FrameConn
FirstClientMessage
[]
byte
Options
RelayOptions
}
// RunEntry 是 openai_ws_v2 包对外的统一入口。
func
RunEntry
(
input
EntryInput
)
(
RelayResult
,
*
RelayExit
)
{
return
runCaddyStyleRelay
(
input
.
Ctx
,
input
.
ClientConn
,
input
.
UpstreamConn
,
input
.
FirstClientMessage
,
input
.
Options
,
)
}
backend/internal/service/openai_ws_v2/metrics.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"sync/atomic"
)
// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。
type
MetricsSnapshot
struct
{
SemanticMutationTotal
int64
`json:"semantic_mutation_total"`
UsageParseFailureTotal
int64
`json:"usage_parse_failure_total"`
}
var
(
// passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。
passthroughSemanticMutationTotal
atomic
.
Int64
passthroughUsageParseFailureTotal
atomic
.
Int64
)
func
recordUsageParseFailure
()
{
passthroughUsageParseFailureTotal
.
Add
(
1
)
}
// SnapshotMetrics 返回当前 passthrough 指标快照。
func
SnapshotMetrics
()
MetricsSnapshot
{
return
MetricsSnapshot
{
SemanticMutationTotal
:
passthroughSemanticMutationTotal
.
Load
(),
UsageParseFailureTotal
:
passthroughUsageParseFailureTotal
.
Load
(),
}
}
backend/internal/service/openai_ws_v2/passthrough_relay.go
0 → 100644
View file @
7076717b
This diff is collapsed.
Click to expand it.
backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
coderws
"github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func
TestRunEntry_DelegatesRelay
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`
),
},
},
true
)
result
,
relayExit
:=
RunEntry
(
EntryInput
{
Ctx
:
context
.
Background
(),
ClientConn
:
clientConn
,
UpstreamConn
:
upstreamConn
,
FirstClientMessage
:
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
),
})
require
.
Nil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"resp_entry"
,
result
.
RequestID
)
}
func
TestRunClientToUpstream_ErrorPaths
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Run
(
"read client eof"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
exitCh
:=
make
(
chan
relayExitSignal
,
1
)
runClientToUpstream
(
context
.
Background
(),
newPassthroughTestFrameConn
(
nil
,
true
),
func
(
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
return
nil
},
func
()
{},
nil
,
nil
,
exitCh
,
)
sig
:=
<-
exitCh
require
.
Equal
(
t
,
"read_client"
,
sig
.
stage
)
require
.
True
(
t
,
sig
.
graceful
)
})
t
.
Run
(
"write upstream failed"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
exitCh
:=
make
(
chan
relayExitSignal
,
1
)
runClientToUpstream
(
context
.
Background
(),
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"x":1}`
)},
},
true
),
func
(
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
return
errors
.
New
(
"boom"
)
},
func
()
{},
nil
,
nil
,
exitCh
,
)
sig
:=
<-
exitCh
require
.
Equal
(
t
,
"write_upstream"
,
sig
.
stage
)
require
.
False
(
t
,
sig
.
graceful
)
})
t
.
Run
(
"forwarded counter and trace callback"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
exitCh
:=
make
(
chan
relayExitSignal
,
1
)
forwarded
:=
&
atomic
.
Int64
{}
traces
:=
make
([]
RelayTraceEvent
,
0
,
2
)
runClientToUpstream
(
context
.
Background
(),
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"x":1}`
)},
},
true
),
func
(
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
return
nil
},
func
()
{},
forwarded
,
func
(
event
RelayTraceEvent
)
{
traces
=
append
(
traces
,
event
)
},
exitCh
,
)
sig
:=
<-
exitCh
require
.
Equal
(
t
,
"read_client"
,
sig
.
stage
)
require
.
Equal
(
t
,
int64
(
1
),
forwarded
.
Load
())
require
.
NotEmpty
(
t
,
traces
)
})
}
func
TestRunUpstreamToClient_ErrorAndDropPaths
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Run
(
"read upstream eof"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
exitCh
:=
make
(
chan
relayExitSignal
,
1
)
drop
:=
&
atomic
.
Bool
{}
drop
.
Store
(
false
)
runUpstreamToClient
(
context
.
Background
(),
newPassthroughTestFrameConn
(
nil
,
true
),
func
(
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
return
nil
},
time
.
Now
(),
time
.
Now
,
&
relayState
{},
nil
,
nil
,
drop
,
nil
,
nil
,
func
()
{},
nil
,
exitCh
,
)
sig
:=
<-
exitCh
require
.
Equal
(
t
,
"read_upstream"
,
sig
.
stage
)
require
.
True
(
t
,
sig
.
graceful
)
})
t
.
Run
(
"write client failed"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
exitCh
:=
make
(
chan
relayExitSignal
,
1
)
drop
:=
&
atomic
.
Bool
{}
drop
.
Store
(
false
)
runUpstreamToClient
(
context
.
Background
(),
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.output_text.delta","delta":"x"}`
)},
},
true
),
func
(
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
return
errors
.
New
(
"write failed"
)
},
time
.
Now
(),
time
.
Now
,
&
relayState
{},
nil
,
nil
,
drop
,
nil
,
nil
,
func
()
{},
nil
,
exitCh
,
)
sig
:=
<-
exitCh
require
.
Equal
(
t
,
"write_client"
,
sig
.
stage
)
})
t
.
Run
(
"drop downstream and stop on terminal"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
exitCh
:=
make
(
chan
relayExitSignal
,
1
)
drop
:=
&
atomic
.
Bool
{}
drop
.
Store
(
true
)
dropped
:=
&
atomic
.
Int64
{}
runUpstreamToClient
(
context
.
Background
(),
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`
),
},
},
true
),
func
(
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
return
nil
},
time
.
Now
(),
time
.
Now
,
&
relayState
{},
nil
,
nil
,
drop
,
nil
,
dropped
,
func
()
{},
nil
,
exitCh
,
)
sig
:=
<-
exitCh
require
.
Equal
(
t
,
"drain_terminal"
,
sig
.
stage
)
require
.
True
(
t
,
sig
.
graceful
)
require
.
Equal
(
t
,
int64
(
1
),
dropped
.
Load
())
})
}
func
TestRunIdleWatchdog_NoTimeoutWhenDisabled
(
t
*
testing
.
T
)
{
t
.
Parallel
()
exitCh
:=
make
(
chan
relayExitSignal
,
1
)
lastActivity
:=
&
atomic
.
Int64
{}
lastActivity
.
Store
(
time
.
Now
()
.
UnixNano
())
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
defer
cancel
()
go
runIdleWatchdog
(
ctx
,
time
.
Now
,
0
,
lastActivity
,
nil
,
exitCh
)
select
{
case
<-
exitCh
:
t
.
Fatal
(
"unexpected idle timeout signal"
)
case
<-
time
.
After
(
200
*
time
.
Millisecond
)
:
}
}
func
TestHelperFunctionsCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
Equal
(
t
,
"text"
,
relayMessageTypeString
(
coderws
.
MessageText
))
require
.
Equal
(
t
,
"binary"
,
relayMessageTypeString
(
coderws
.
MessageBinary
))
require
.
Contains
(
t
,
relayMessageTypeString
(
coderws
.
MessageType
(
99
)),
"unknown("
)
require
.
Equal
(
t
,
""
,
relayErrorString
(
nil
))
require
.
Equal
(
t
,
"x"
,
relayErrorString
(
errors
.
New
(
"x"
)))
require
.
True
(
t
,
isDisconnectError
(
io
.
EOF
))
require
.
True
(
t
,
isDisconnectError
(
net
.
ErrClosed
))
require
.
True
(
t
,
isDisconnectError
(
context
.
Canceled
))
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusGoingAway
}))
require
.
True
(
t
,
isDisconnectError
(
errors
.
New
(
"broken pipe"
)))
require
.
False
(
t
,
isDisconnectError
(
errors
.
New
(
"unrelated"
)))
require
.
True
(
t
,
isTokenEvent
(
"response.output_text.delta"
))
require
.
True
(
t
,
isTokenEvent
(
"response.output_audio.delta"
))
require
.
True
(
t
,
isTokenEvent
(
"response.completed"
))
require
.
False
(
t
,
isTokenEvent
(
""
))
require
.
False
(
t
,
isTokenEvent
(
"response.created"
))
require
.
Equal
(
t
,
2
*
time
.
Second
,
minDuration
(
2
*
time
.
Second
,
5
*
time
.
Second
))
require
.
Equal
(
t
,
2
*
time
.
Second
,
minDuration
(
5
*
time
.
Second
,
2
*
time
.
Second
))
require
.
Equal
(
t
,
5
*
time
.
Second
,
minDuration
(
0
,
5
*
time
.
Second
))
require
.
Equal
(
t
,
2
*
time
.
Second
,
minDuration
(
2
*
time
.
Second
,
0
))
ch
:=
make
(
chan
relayExitSignal
,
1
)
ch
<-
relayExitSignal
{
stage
:
"ok"
}
sig
,
ok
:=
waitRelayExit
(
ch
,
10
*
time
.
Millisecond
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"ok"
,
sig
.
stage
)
ch
<-
relayExitSignal
{
stage
:
"ok2"
}
sig
,
ok
=
waitRelayExit
(
ch
,
0
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"ok2"
,
sig
.
stage
)
_
,
ok
=
waitRelayExit
(
ch
,
10
*
time
.
Millisecond
)
require
.
False
(
t
,
ok
)
n
,
ok
:=
parseUsageIntField
(
gjson
.
Get
(
`{"n":3}`
,
"n"
),
true
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
3
,
n
)
_
,
ok
=
parseUsageIntField
(
gjson
.
Get
(
`{"n":"x"}`
,
"n"
),
true
)
require
.
False
(
t
,
ok
)
n
,
ok
=
parseUsageIntField
(
gjson
.
Result
{},
false
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
0
,
n
)
_
,
ok
=
parseUsageIntField
(
gjson
.
Result
{},
true
)
require
.
False
(
t
,
ok
)
}
func
TestParseUsageAndEnrichCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
state
:=
&
relayState
{}
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`
),
"response.completed"
,
nil
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
InputTokens
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`
),
"response.completed"
,
nil
,
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
InputTokens
,
"部分字段解析失败时不应累加 usage"
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
CacheReadInputTokens
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`
),
"response.completed"
,
nil
,
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
InputTokens
,
"必填 usage 字段缺失时不应累加 usage"
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
CacheReadInputTokens
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`
),
"response.completed"
,
nil
)
require
.
Equal
(
t
,
2
,
state
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
1
,
state
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
1
,
state
.
usage
.
CacheReadInputTokens
)
result
:=
&
RelayResult
{}
enrichResult
(
result
,
state
,
5
*
time
.
Millisecond
)
require
.
Equal
(
t
,
state
.
usage
.
InputTokens
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
5
*
time
.
Millisecond
,
result
.
Duration
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`
),
"response.in_progress"
,
nil
)
require
.
Equal
(
t
,
2
,
state
.
usage
.
InputTokens
)
enrichResult
(
nil
,
state
,
0
)
}
func
TestEmitTurnCompleteCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 非 terminal 事件不应触发。
called
:=
0
emitTurnComplete
(
func
(
turn
RelayTurnResult
)
{
called
++
},
&
relayState
{
requestModel
:
"gpt-5"
},
observedUpstreamEvent
{
terminal
:
false
,
eventType
:
"response.output_text.delta"
,
responseID
:
"resp_ignored"
,
usage
:
Usage
{
InputTokens
:
1
},
})
require
.
Equal
(
t
,
0
,
called
)
// 缺少 response_id 时不应触发。
emitTurnComplete
(
func
(
turn
RelayTurnResult
)
{
called
++
},
&
relayState
{
requestModel
:
"gpt-5"
},
observedUpstreamEvent
{
terminal
:
true
,
eventType
:
"response.completed"
,
})
require
.
Equal
(
t
,
0
,
called
)
// terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。
var
got
RelayTurnResult
emitTurnComplete
(
func
(
turn
RelayTurnResult
)
{
called
++
got
=
turn
},
nil
,
observedUpstreamEvent
{
terminal
:
true
,
eventType
:
"response.completed"
,
responseID
:
"resp_emit"
,
usage
:
Usage
{
InputTokens
:
2
,
OutputTokens
:
3
},
})
require
.
Equal
(
t
,
1
,
called
)
require
.
Equal
(
t
,
"resp_emit"
,
got
.
RequestID
)
require
.
Equal
(
t
,
"response.completed"
,
got
.
TerminalEventType
)
require
.
Equal
(
t
,
2
,
got
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
3
,
got
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
""
,
got
.
RequestModel
)
}
func
TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusNormalClosure
}))
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusNoStatusRcvd
}))
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusAbnormalClosure
}))
require
.
True
(
t
,
isDisconnectError
(
errors
.
New
(
"connection reset by peer"
)))
require
.
False
(
t
,
isDisconnectError
(
errors
.
New
(
" "
)))
}
func
TestIsTokenEventCoverageBranches
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
False
(
t
,
isTokenEvent
(
"response.in_progress"
))
require
.
False
(
t
,
isTokenEvent
(
"response.output_item.added"
))
require
.
True
(
t
,
isTokenEvent
(
"response.output_audio.delta"
))
require
.
True
(
t
,
isTokenEvent
(
"response.output"
))
require
.
True
(
t
,
isTokenEvent
(
"response.done"
))
}
func
TestRelayTurnTimingHelpersCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
now
:=
time
.
Unix
(
100
,
0
)
// nil state
require
.
Nil
(
t
,
openAIWSRelayGetOrInitTurnTiming
(
nil
,
"resp_nil"
,
now
))
_
,
ok
:=
openAIWSRelayDeleteTurnTiming
(
nil
,
"resp_nil"
)
require
.
False
(
t
,
ok
)
state
:=
&
relayState
{}
timing
:=
openAIWSRelayGetOrInitTurnTiming
(
state
,
"resp_a"
,
now
)
require
.
NotNil
(
t
,
timing
)
require
.
Equal
(
t
,
now
,
timing
.
startAt
)
// 再次获取返回同一条 timing
timing2
:=
openAIWSRelayGetOrInitTurnTiming
(
state
,
"resp_a"
,
now
.
Add
(
5
*
time
.
Second
))
require
.
NotNil
(
t
,
timing2
)
require
.
Equal
(
t
,
now
,
timing2
.
startAt
)
// 删除存在键
deleted
,
ok
:=
openAIWSRelayDeleteTurnTiming
(
state
,
"resp_a"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
now
,
deleted
.
startAt
)
// 删除不存在键
_
,
ok
=
openAIWSRelayDeleteTurnTiming
(
state
,
"resp_a"
)
require
.
False
(
t
,
ok
)
}
func
TestObserveUpstreamMessage_ResponseIDFallbackPolicy
(
t
*
testing
.
T
)
{
t
.
Parallel
()
state
:=
&
relayState
{
requestModel
:
"gpt-5"
}
startAt
:=
time
.
Unix
(
0
,
0
)
now
:=
startAt
nowFn
:=
func
()
time
.
Time
{
now
=
now
.
Add
(
5
*
time
.
Millisecond
)
return
now
}
// 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。
observed
:=
observeUpstreamMessage
(
state
,
[]
byte
(
`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`
),
startAt
,
nowFn
,
nil
,
)
require
.
False
(
t
,
observed
.
terminal
)
require
.
Equal
(
t
,
""
,
observed
.
responseID
)
// terminal:允许兜底用顶层 id(用于兼容少数字段变体)。
observed
=
observeUpstreamMessage
(
state
,
[]
byte
(
`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`
),
startAt
,
nowFn
,
nil
,
)
require
.
True
(
t
,
observed
.
terminal
)
require
.
Equal
(
t
,
"resp_fallback"
,
observed
.
responseID
)
}
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
0 → 100644
View file @
7076717b
This diff is collapsed.
Click to expand it.
backend/internal/service/openai_ws_v2_passthrough_adapter.go
0 → 100644
View file @
7076717b
package
service
import
(
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
openaiwsv2
"github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type
openAIWSClientFrameConn
struct
{
conn
*
coderws
.
Conn
}
const
openaiWSV2PassthroughModeFields
=
"ws_mode=passthrough ws_router=v2"
var
_
openaiwsv2
.
FrameConn
=
(
*
openAIWSClientFrameConn
)(
nil
)
func
(
c
*
openAIWSClientFrameConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
coderws
.
MessageText
,
nil
,
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
return
c
.
conn
.
Read
(
ctx
)
}
func
(
c
*
openAIWSClientFrameConn
)
WriteFrame
(
ctx
context
.
Context
,
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
return
c
.
conn
.
Write
(
ctx
,
msgType
,
payload
)
}
func
(
c
*
openAIWSClientFrameConn
)
Close
()
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
nil
}
_
=
c
.
conn
.
Close
(
coderws
.
StatusNormalClosure
,
""
)
_
=
c
.
conn
.
CloseNow
()
return
nil
}
func
(
s
*
OpenAIGatewayService
)
proxyResponsesWebSocketV2Passthrough
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
clientConn
*
coderws
.
Conn
,
account
*
Account
,
token
string
,
firstClientMessage
[]
byte
,
hooks
*
OpenAIWSIngressHooks
,
wsDecision
OpenAIWSProtocolDecision
,
)
error
{
if
s
==
nil
{
return
errors
.
New
(
"service is nil"
)
}
if
clientConn
==
nil
{
return
errors
.
New
(
"client websocket is nil"
)
}
if
account
==
nil
{
return
errors
.
New
(
"account is nil"
)
}
if
strings
.
TrimSpace
(
token
)
==
""
{
return
errors
.
New
(
"token is empty"
)
}
requestModel
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
firstClientMessage
,
"model"
)
.
String
())
requestPreviousResponseID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
firstClientMessage
,
"previous_response_id"
)
.
String
())
logOpenAIWSV2Passthrough
(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
requestModel
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
requestPreviousResponseID
,
openAIWSIDValueMaxLen
),
openaiwsv2RelayMessageTypeName
(
coderws
.
MessageText
),
len
(
firstClientMessage
),
)
wsURL
,
err
:=
s
.
buildOpenAIResponsesWSURL
(
account
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"build ws url: %w"
,
err
)
}
wsHost
:=
"-"
wsPath
:=
"-"
if
parsedURL
,
parseErr
:=
url
.
Parse
(
wsURL
);
parseErr
==
nil
&&
parsedURL
!=
nil
{
wsHost
=
normalizeOpenAIWSLogValue
(
parsedURL
.
Host
)
wsPath
=
normalizeOpenAIWSLogValue
(
parsedURL
.
Path
)
}
logOpenAIWSV2Passthrough
(
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v"
,
account
.
ID
,
wsHost
,
wsPath
,
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
,
)
isCodexCLI
:=
false
if
c
!=
nil
{
isCodexCLI
=
openai
.
IsCodexCLIRequest
(
c
.
GetHeader
(
"User-Agent"
))
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
ForceCodexCLI
{
isCodexCLI
=
true
}
headers
,
_
:=
s
.
buildOpenAIWSHeaders
(
c
,
account
,
token
,
wsDecision
,
isCodexCLI
,
""
,
""
,
""
)
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
dialer
:=
s
.
getOpenAIWSPassthroughDialer
()
if
dialer
==
nil
{
return
errors
.
New
(
"openai ws passthrough dialer is nil"
)
}
dialCtx
,
cancelDial
:=
context
.
WithTimeout
(
ctx
,
s
.
openAIWSDialTimeout
())
defer
cancelDial
()
upstreamConn
,
statusCode
,
handshakeHeaders
,
err
:=
dialer
.
Dial
(
dialCtx
,
wsURL
,
headers
,
proxyURL
)
if
err
!=
nil
{
logOpenAIWSV2Passthrough
(
"relay_dial_failed account_id=%d status_code=%d err=%s"
,
account
.
ID
,
statusCode
,
truncateOpenAIWSLogValue
(
err
.
Error
(),
openAIWSLogValueMaxLen
),
)
return
s
.
mapOpenAIWSPassthroughDialError
(
err
,
statusCode
,
handshakeHeaders
)
}
defer
func
()
{
_
=
upstreamConn
.
Close
()
}()
logOpenAIWSV2Passthrough
(
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s"
,
account
.
ID
,
statusCode
,
openAIWSHeaderValueForLog
(
handshakeHeaders
,
"x-request-id"
),
)
upstreamFrameConn
,
ok
:=
upstreamConn
.
(
openaiwsv2
.
FrameConn
)
if
!
ok
{
return
errors
.
New
(
"openai ws passthrough upstream connection does not support frame relay"
)
}
completedTurns
:=
atomic
.
Int32
{}
relayResult
,
relayExit
:=
openaiwsv2
.
RunEntry
(
openaiwsv2
.
EntryInput
{
Ctx
:
ctx
,
ClientConn
:
&
openAIWSClientFrameConn
{
conn
:
clientConn
},
UpstreamConn
:
upstreamFrameConn
,
FirstClientMessage
:
firstClientMessage
,
Options
:
openaiwsv2
.
RelayOptions
{
WriteTimeout
:
s
.
openAIWSWriteTimeout
(),
IdleTimeout
:
s
.
openAIWSPassthroughIdleTimeout
(),
FirstMessageType
:
coderws
.
MessageText
,
OnUsageParseFailure
:
func
(
eventType
string
,
usageRaw
string
)
{
logOpenAIWSV2Passthrough
(
"usage_parse_failed event_type=%s usage_raw=%s"
,
truncateOpenAIWSLogValue
(
eventType
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
usageRaw
,
openAIWSLogValueMaxLen
),
)
},
OnTurnComplete
:
func
(
turn
openaiwsv2
.
RelayTurnResult
)
{
turnNo
:=
int
(
completedTurns
.
Add
(
1
))
turnResult
:=
&
OpenAIForwardResult
{
RequestID
:
turn
.
RequestID
,
Usage
:
OpenAIUsage
{
InputTokens
:
turn
.
Usage
.
InputTokens
,
OutputTokens
:
turn
.
Usage
.
OutputTokens
,
CacheCreationInputTokens
:
turn
.
Usage
.
CacheCreationInputTokens
,
CacheReadInputTokens
:
turn
.
Usage
.
CacheReadInputTokens
,
},
Model
:
turn
.
RequestModel
,
Stream
:
true
,
OpenAIWSMode
:
true
,
Duration
:
turn
.
Duration
,
FirstTokenMs
:
turn
.
FirstTokenMs
,
}
logOpenAIWSV2Passthrough
(
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d"
,
account
.
ID
,
turnNo
,
truncateOpenAIWSLogValue
(
turnResult
.
RequestID
,
openAIWSIDValueMaxLen
),
truncateOpenAIWSLogValue
(
turn
.
TerminalEventType
,
openAIWSLogValueMaxLen
),
turnResult
.
Duration
.
Milliseconds
(),
openAIWSFirstTokenMsForLog
(
turnResult
.
FirstTokenMs
),
turnResult
.
Usage
.
InputTokens
,
turnResult
.
Usage
.
OutputTokens
,
turnResult
.
Usage
.
CacheReadInputTokens
,
)
if
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
hooks
.
AfterTurn
(
turnNo
,
turnResult
,
nil
)
}
},
OnTrace
:
func
(
event
openaiwsv2
.
RelayTraceEvent
)
{
logOpenAIWSV2Passthrough
(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
event
.
Stage
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
event
.
Direction
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
event
.
MessageType
,
openAIWSLogValueMaxLen
),
event
.
PayloadBytes
,
event
.
Graceful
,
event
.
WroteDownstream
,
truncateOpenAIWSLogValue
(
event
.
Error
,
openAIWSLogValueMaxLen
),
)
},
},
})
result
:=
&
OpenAIForwardResult
{
RequestID
:
relayResult
.
RequestID
,
Usage
:
OpenAIUsage
{
InputTokens
:
relayResult
.
Usage
.
InputTokens
,
OutputTokens
:
relayResult
.
Usage
.
OutputTokens
,
CacheCreationInputTokens
:
relayResult
.
Usage
.
CacheCreationInputTokens
,
CacheReadInputTokens
:
relayResult
.
Usage
.
CacheReadInputTokens
,
},
Model
:
relayResult
.
RequestModel
,
Stream
:
true
,
OpenAIWSMode
:
true
,
Duration
:
relayResult
.
Duration
,
FirstTokenMs
:
relayResult
.
FirstTokenMs
,
}
turnCount
:=
int
(
completedTurns
.
Load
())
if
relayExit
==
nil
{
logOpenAIWSV2Passthrough
(
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
result
.
RequestID
,
openAIWSIDValueMaxLen
),
truncateOpenAIWSLogValue
(
relayResult
.
TerminalEventType
,
openAIWSLogValueMaxLen
),
result
.
Duration
.
Milliseconds
(),
relayResult
.
ClientToUpstreamFrames
,
relayResult
.
UpstreamToClientFrames
,
relayResult
.
DroppedDownstreamFrames
,
turnCount
,
)
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
if
turnCount
==
0
&&
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
hooks
.
AfterTurn
(
1
,
result
,
nil
)
}
return
nil
}
logOpenAIWSV2Passthrough
(
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
relayExit
.
Stage
,
openAIWSLogValueMaxLen
),
relayExit
.
WroteDownstream
,
truncateOpenAIWSLogValue
(
relayErrorText
(
relayExit
.
Err
),
openAIWSLogValueMaxLen
),
result
.
Duration
.
Milliseconds
(),
relayResult
.
ClientToUpstreamFrames
,
relayResult
.
UpstreamToClientFrames
,
relayResult
.
DroppedDownstreamFrames
,
turnCount
,
)
relayErr
:=
relayExit
.
Err
if
relayExit
.
Stage
==
"idle_timeout"
{
relayErr
=
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"client websocket idle timeout"
,
relayErr
,
)
}
turnErr
:=
wrapOpenAIWSIngressTurnError
(
relayExit
.
Stage
,
relayErr
,
relayExit
.
WroteDownstream
,
)
if
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
hooks
.
AfterTurn
(
turnCount
+
1
,
nil
,
turnErr
)
}
return
turnErr
}
func
(
s
*
OpenAIGatewayService
)
mapOpenAIWSPassthroughDialError
(
err
error
,
statusCode
int
,
handshakeHeaders
http
.
Header
,
)
error
{
if
err
==
nil
{
return
nil
}
wrappedErr
:=
err
var
dialErr
*
openAIWSDialError
if
!
errors
.
As
(
err
,
&
dialErr
)
{
wrappedErr
=
&
openAIWSDialError
{
StatusCode
:
statusCode
,
ResponseHeaders
:
cloneHeader
(
handshakeHeaders
),
Err
:
err
,
}
}
if
errors
.
Is
(
err
,
context
.
Canceled
)
{
return
err
}
if
errors
.
Is
(
err
,
context
.
DeadlineExceeded
)
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusTryAgainLater
,
"upstream websocket connect timeout"
,
wrappedErr
,
)
}
if
statusCode
==
http
.
StatusTooManyRequests
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusTryAgainLater
,
"upstream websocket is busy, please retry later"
,
wrappedErr
,
)
}
if
statusCode
==
http
.
StatusUnauthorized
||
statusCode
==
http
.
StatusForbidden
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"upstream websocket authentication failed"
,
wrappedErr
,
)
}
if
statusCode
>=
http
.
StatusBadRequest
&&
statusCode
<
http
.
StatusInternalServerError
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"upstream websocket handshake rejected"
,
wrappedErr
,
)
}
return
fmt
.
Errorf
(
"openai ws passthrough dial: %w"
,
wrappedErr
)
}
func
openaiwsv2RelayMessageTypeName
(
msgType
coderws
.
MessageType
)
string
{
switch
msgType
{
case
coderws
.
MessageText
:
return
"text"
case
coderws
.
MessageBinary
:
return
"binary"
default
:
return
fmt
.
Sprintf
(
"unknown(%d)"
,
msgType
)
}
}
func
relayErrorText
(
err
error
)
string
{
if
err
==
nil
{
return
""
}
return
err
.
Error
()
}
func
openAIWSFirstTokenMsForLog
(
firstTokenMs
*
int
)
int
{
if
firstTokenMs
==
nil
{
return
-
1
}
return
*
firstTokenMs
}
func
logOpenAIWSV2Passthrough
(
format
string
,
args
...
any
)
{
logger
.
LegacyPrintf
(
"service.openai_ws_v2"
,
"[OpenAI WS v2 passthrough] %s "
+
format
,
append
([]
any
{
openaiWSV2PassthroughModeFields
},
args
...
)
...
,
)
}
deploy/config.example.yaml
View file @
7076717b
...
...
@@ -209,8 +209,9 @@ gateway:
openai_ws
:
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
mode_router_v2_enabled
:
false
# ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
ingress_mode_default
:
shared
# ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
# 兼容旧值:shared/dedicated 会按 ctx_pool 处理。
ingress_mode_default
:
ctx_pool
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
enabled
:
true
# 按账号类型细分开关
...
...
frontend/src/components/account/CreateAccountModal.vue
View file @
7076717b
...
...
@@ -1807,7 +1807,7 @@
<
/div
>
<
/div
>
<!--
OpenAI
WS
Mode
三态
(
off
/
shared
/
dedicated
)
-->
<!--
OpenAI
WS
Mode
三态
(
off
/
ctx_pool
/
passthrough
)
-->
<
div
v
-
if
=
"
form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')
"
class
=
"
border-t border-gray-200 pt-4 dark:border-dark-600
"
...
...
@@ -1819,7 +1819,7 @@
{{
t
(
'
admin.accounts.openai.wsModeDesc
'
)
}}
<
/p
>
<
p
class
=
"
mt-1 text-xs text-gray-500 dark:text-gray-400
"
>
{{
t
(
'
admin.accounts.openai.ws
ModeConcurrencyHint
'
)
}}
{{
t
(
openAIWS
ModeConcurrencyHint
Key
)
}}
<
/p
>
<
/div
>
<
div
class
=
"
w-52
"
>
...
...
@@ -2341,10 +2341,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import
{
formatDateTimeLocalInput
,
parseDateTimeLocalInput
}
from
'
@/utils/format
'
import
{
createStableObjectKeyResolver
}
from
'
@/utils/stableObjectKey
'
import
{
OPENAI_WS_MODE_
DEDICATED
,
OPENAI_WS_MODE_
CTX_POOL
,
OPENAI_WS_MODE_OFF
,
OPENAI_WS_MODE_
SHARED
,
OPENAI_WS_MODE_
PASSTHROUGH
,
isOpenAIWSModeEnabled
,
resolveOpenAIWSModeConcurrencyHintKey
,
type
OpenAIWSMode
}
from
'
@/utils/openaiWsMode
'
import
OAuthAuthorizationFlow
from
'
./OAuthAuthorizationFlow.vue
'
...
...
@@ -2541,8 +2542,8 @@ const geminiSelectedTier = computed(() => {
const
openAIWSModeOptions
=
computed
(()
=>
[
{
value
:
OPENAI_WS_MODE_OFF
,
label
:
t
(
'
admin.accounts.openai.wsModeOff
'
)
}
,
{
value
:
OPENAI_WS_MODE_
SHARED
,
label
:
t
(
'
admin.accounts.openai.wsMode
Shared
'
)
}
,
{
value
:
OPENAI_WS_MODE_
DEDICATED
,
label
:
t
(
'
admin.accounts.openai.wsMode
Dedicated
'
)
}
{
value
:
OPENAI_WS_MODE_
CTX_POOL
,
label
:
t
(
'
admin.accounts.openai.wsMode
CtxPool
'
)
}
,
{
value
:
OPENAI_WS_MODE_
PASSTHROUGH
,
label
:
t
(
'
admin.accounts.openai.wsMode
Passthrough
'
)
}
])
const
openaiResponsesWebSocketV2Mode
=
computed
({
...
...
@@ -2561,6 +2562,10 @@ const openaiResponsesWebSocketV2Mode = computed({
}
}
)
const
openAIWSModeConcurrencyHintKey
=
computed
(()
=>
resolveOpenAIWSModeConcurrencyHintKey
(
openaiResponsesWebSocketV2Mode
.
value
)
)
const
isOpenAIModelRestrictionDisabled
=
computed
(()
=>
form
.
platform
===
'
openai
'
&&
openaiPassthroughEnabled
.
value
)
...
...
@@ -3180,10 +3185,13 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
}
const
extra
:
Record
<
string
,
unknown
>
=
{
...(
base
||
{
}
)
}
if
(
accountCategory
.
value
===
'
oauth-based
'
)
{
extra
.
openai_oauth_responses_websockets_v2_mode
=
openaiOAuthResponsesWebSocketV2Mode
.
value
extra
.
openai_apikey_responses_websockets_v2_mode
=
openaiAPIKeyResponsesWebSocketV2Mode
.
value
extra
.
openai_oauth_responses_websockets_v2_enabled
=
isOpenAIWSModeEnabled
(
openaiOAuthResponsesWebSocketV2Mode
.
value
)
}
else
if
(
accountCategory
.
value
===
'
apikey
'
)
{
extra
.
openai_apikey_responses_websockets_v2_mode
=
openaiAPIKeyResponsesWebSocketV2Mode
.
value
extra
.
openai_apikey_responses_websockets_v2_enabled
=
isOpenAIWSModeEnabled
(
openaiAPIKeyResponsesWebSocketV2Mode
.
value
)
}
// 清理兼容旧键,统一改用分类型开关。
delete
extra
.
responses_websockets_v2_enabled
delete
extra
.
openai_ws_enabled
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment