Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7076717b
Unverified
Commit
7076717b
authored
Mar 05, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 05, 2026
Browse files
Merge pull request #772 from mt21625457/aicodex2api-main
feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
parents
33988637
c0a4fcea
Changes
25
Hide whitespace changes
Inline
Side-by-side
backend/internal/config/config.go
View file @
7076717b
...
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
...
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
type
GatewayOpenAIWSConfig
struct
{
type
GatewayOpenAIWSConfig
struct
{
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
ModeRouterV2Enabled
bool
`mapstructure:"mode_router_v2_enabled"`
ModeRouterV2Enabled
bool
`mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式(off/
shared/dedicated
)
// IngressModeDefault: ingress 默认模式(off/
ctx_pool/passthrough
)
IngressModeDefault
string
`mapstructure:"ingress_mode_default"`
IngressModeDefault
string
`mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true)
// Enabled: 全局总开关(默认 true)
Enabled
bool
`mapstructure:"enabled"`
Enabled
bool
`mapstructure:"enabled"`
...
@@ -1335,7 +1335,7 @@ func setDefaults() {
...
@@ -1335,7 +1335,7 @@ func setDefaults() {
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
viper
.
SetDefault
(
"gateway.openai_ws.enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.mode_router_v2_enabled"
,
false
)
viper
.
SetDefault
(
"gateway.openai_ws.mode_router_v2_enabled"
,
false
)
viper
.
SetDefault
(
"gateway.openai_ws.ingress_mode_default"
,
"
shared
"
)
viper
.
SetDefault
(
"gateway.openai_ws.ingress_mode_default"
,
"
ctx_pool
"
)
viper
.
SetDefault
(
"gateway.openai_ws.oauth_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.oauth_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.apikey_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.apikey_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.openai_ws.force_http"
,
false
)
viper
.
SetDefault
(
"gateway.openai_ws.force_http"
,
false
)
...
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
...
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
}
}
if
mode
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
Gateway
.
OpenAIWS
.
IngressModeDefault
));
mode
!=
""
{
if
mode
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
Gateway
.
OpenAIWS
.
IngressModeDefault
));
mode
!=
""
{
switch
mode
{
switch
mode
{
case
"off"
,
"shared"
,
"dedicated"
:
case
"off"
,
"ctx_pool"
,
"passthrough"
:
case
"shared"
,
"dedicated"
:
slog
.
Warn
(
"gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough"
,
"value"
,
mode
)
default
:
default
:
return
fmt
.
Errorf
(
"gateway.openai_ws.ingress_mode_default must be one of off|
shared|dedicated
"
)
return
fmt
.
Errorf
(
"gateway.openai_ws.ingress_mode_default must be one of off|
ctx_pool|passthrough
"
)
}
}
}
}
if
mode
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
Gateway
.
OpenAIWS
.
StoreDisabledConnMode
));
mode
!=
""
{
if
mode
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
Gateway
.
OpenAIWS
.
StoreDisabledConnMode
));
mode
!=
""
{
...
...
backend/internal/config/config_test.go
View file @
7076717b
...
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
...
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
{
if
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
{
t
.
Fatalf
(
"Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false"
)
t
.
Fatalf
(
"Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false"
)
}
}
if
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
!=
"
shared
"
{
if
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
!=
"
ctx_pool
"
{
t
.
Fatalf
(
"Gateway.OpenAIWS.IngressModeDefault = %q, want %q"
,
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
,
"
shared
"
)
t
.
Fatalf
(
"Gateway.OpenAIWS.IngressModeDefault = %q, want %q"
,
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
,
"
ctx_pool
"
)
}
}
}
}
...
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
...
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
wantErr
:
"gateway.openai_ws.store_disabled_conn_mode"
,
wantErr
:
"gateway.openai_ws.store_disabled_conn_mode"
,
},
},
{
{
name
:
"ingress_mode_default 必须为 off|
shared|dedicated
"
,
name
:
"ingress_mode_default 必须为 off|
ctx_pool|passthrough
"
,
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
"invalid"
},
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
"invalid"
},
wantErr
:
"gateway.openai_ws.ingress_mode_default"
,
wantErr
:
"gateway.openai_ws.ingress_mode_default"
,
},
},
...
...
backend/internal/service/account.go
View file @
7076717b
...
@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
...
@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
}
}
const
(
const
(
OpenAIWSIngressModeOff
=
"off"
OpenAIWSIngressModeOff
=
"off"
OpenAIWSIngressModeShared
=
"shared"
OpenAIWSIngressModeShared
=
"shared"
OpenAIWSIngressModeDedicated
=
"dedicated"
OpenAIWSIngressModeDedicated
=
"dedicated"
OpenAIWSIngressModeCtxPool
=
"ctx_pool"
OpenAIWSIngressModePassthrough
=
"passthrough"
)
)
func
normalizeOpenAIWSIngressMode
(
mode
string
)
string
{
func
normalizeOpenAIWSIngressMode
(
mode
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
mode
))
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
mode
))
{
case
OpenAIWSIngressModeOff
:
case
OpenAIWSIngressModeOff
:
return
OpenAIWSIngressModeOff
return
OpenAIWSIngressModeOff
case
OpenAIWSIngressModeCtxPool
:
return
OpenAIWSIngressModeCtxPool
case
OpenAIWSIngressModePassthrough
:
return
OpenAIWSIngressModePassthrough
case
OpenAIWSIngressModeShared
:
case
OpenAIWSIngressModeShared
:
return
OpenAIWSIngressModeShared
return
OpenAIWSIngressModeShared
case
OpenAIWSIngressModeDedicated
:
case
OpenAIWSIngressModeDedicated
:
...
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
...
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
func
normalizeOpenAIWSIngressDefaultMode
(
mode
string
)
string
{
func
normalizeOpenAIWSIngressDefaultMode
(
mode
string
)
string
{
if
normalized
:=
normalizeOpenAIWSIngressMode
(
mode
);
normalized
!=
""
{
if
normalized
:=
normalizeOpenAIWSIngressMode
(
mode
);
normalized
!=
""
{
if
normalized
==
OpenAIWSIngressModeShared
||
normalized
==
OpenAIWSIngressModeDedicated
{
return
OpenAIWSIngressModeCtxPool
}
return
normalized
return
normalized
}
}
return
OpenAIWSIngressMode
Shared
return
OpenAIWSIngressMode
CtxPool
}
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/
shared/dedicated
)。
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/
ctx_pool/passthrough
)。
//
//
// 优先级:
// 优先级:
// 1. 分类型 mode 新字段(string)
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
// 4. defaultMode(非法时回退
shared
)
// 4. defaultMode(非法时回退
ctx_pool
)
func
(
a
*
Account
)
ResolveOpenAIResponsesWebSocketV2Mode
(
defaultMode
string
)
string
{
func
(
a
*
Account
)
ResolveOpenAIResponsesWebSocketV2Mode
(
defaultMode
string
)
string
{
resolvedDefault
:=
normalizeOpenAIWSIngressDefaultMode
(
defaultMode
)
resolvedDefault
:=
normalizeOpenAIWSIngressDefaultMode
(
defaultMode
)
if
a
==
nil
||
!
a
.
IsOpenAI
()
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
{
...
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
...
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return
""
,
false
return
""
,
false
}
}
if
enabled
{
if
enabled
{
return
OpenAIWSIngressMode
Shared
,
true
return
OpenAIWSIngressMode
CtxPool
,
true
}
}
return
OpenAIWSIngressModeOff
,
true
return
OpenAIWSIngressModeOff
,
true
}
}
...
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
...
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
if
mode
,
ok
:=
resolveBoolMode
(
"openai_ws_enabled"
);
ok
{
if
mode
,
ok
:=
resolveBoolMode
(
"openai_ws_enabled"
);
ok
{
return
mode
return
mode
}
}
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
if
resolvedDefault
==
OpenAIWSIngressModeShared
||
resolvedDefault
==
OpenAIWSIngressModeDedicated
{
return
OpenAIWSIngressModeCtxPool
}
return
resolvedDefault
return
resolvedDefault
}
}
...
...
backend/internal/service/account_openai_passthrough_test.go
View file @
7076717b
...
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
...
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
}
}
func
TestAccount_ResolveOpenAIResponsesWebSocketV2Mode
(
t
*
testing
.
T
)
{
func
TestAccount_ResolveOpenAIResponsesWebSocketV2Mode
(
t
*
testing
.
T
)
{
t
.
Run
(
"default fallback to
shared
"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"default fallback to
ctx_pool
"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
Extra
:
map
[
string
]
any
{},
}
}
require
.
Equal
(
t
,
OpenAIWSIngressMode
Shared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
""
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
CtxPool
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
""
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
Shared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
"invalid"
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
CtxPool
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
"invalid"
))
})
})
t
.
Run
(
"oauth mode field has highest priority"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"oauth mode field has highest priority"
,
func
(
t
*
testing
.
T
)
{
...
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
...
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Dedicated
,
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Passthrough
,
"openai_oauth_responses_websockets_v2_enabled"
:
false
,
"openai_oauth_responses_websockets_v2_enabled"
:
false
,
"responses_websockets_v2_enabled"
:
false
,
"responses_websockets_v2_enabled"
:
false
,
},
},
}
}
require
.
Equal
(
t
,
OpenAIWSIngressMode
Dedicated
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
Shared
))
require
.
Equal
(
t
,
OpenAIWSIngressMode
Passthrough
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
CtxPool
))
})
})
t
.
Run
(
"legacy enabled maps to
shared
"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"legacy enabled maps to
ctx_pool
"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Type
:
AccountTypeAPIKey
,
...
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
...
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled"
:
true
,
"responses_websockets_v2_enabled"
:
true
,
},
},
}
}
require
.
Equal
(
t
,
OpenAIWSIngressModeShared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
require
.
Equal
(
t
,
OpenAIWSIngressModeCtxPool
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
})
t
.
Run
(
"shared/dedicated mode strings are compatible with ctx_pool"
,
func
(
t
*
testing
.
T
)
{
shared
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeShared
,
},
}
dedicated
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeDedicated
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeShared
,
shared
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
require
.
Equal
(
t
,
OpenAIWSIngressModeDedicated
,
dedicated
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
require
.
Equal
(
t
,
OpenAIWSIngressModeCtxPool
,
normalizeOpenAIWSIngressDefaultMode
(
OpenAIWSIngressModeShared
))
require
.
Equal
(
t
,
OpenAIWSIngressModeCtxPool
,
normalizeOpenAIWSIngressDefaultMode
(
OpenAIWSIngressModeDedicated
))
})
})
t
.
Run
(
"legacy disabled maps to off"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"legacy disabled maps to off"
,
func
(
t
*
testing
.
T
)
{
...
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
...
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled"
:
true
,
"responses_websockets_v2_enabled"
:
true
,
},
},
}
}
require
.
Equal
(
t
,
OpenAIWSIngressModeOff
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
Shared
))
require
.
Equal
(
t
,
OpenAIWSIngressModeOff
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressMode
CtxPool
))
})
})
t
.
Run
(
"non openai always off"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"non openai always off"
,
func
(
t
*
testing
.
T
)
{
...
...
backend/internal/service/openai_gateway_service.go
View file @
7076717b
...
@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
...
@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
toolCorrector
*
CodexToolCorrector
toolCorrector
*
CodexToolCorrector
openaiWSResolver
OpenAIWSProtocolResolver
openaiWSResolver
OpenAIWSProtocolResolver
openaiWSPoolOnce
sync
.
Once
openaiWSPoolOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
openaiSchedulerOnce
sync
.
Once
openaiSchedulerOnce
sync
.
Once
openaiWSPool
*
openAIWSConnPool
openaiWSPassthroughDialerOnce
sync
.
Once
openaiWSStateStore
OpenAIWSStateStore
openaiWSPool
*
openAIWSConnPool
openaiScheduler
OpenAIAccountScheduler
openaiWSStateStore
OpenAIWSStateStore
openaiAccountStats
*
openAIAccountRuntimeStats
openaiScheduler
OpenAIAccountScheduler
openaiWSPassthroughDialer
openAIWSClientDialer
openaiAccountStats
*
openAIAccountRuntimeStats
openaiWSFallbackUntil
sync
.
Map
// key: int64(accountID), value: time.Time
openaiWSFallbackUntil
sync
.
Map
// key: int64(accountID), value: time.Time
openaiWSRetryMetrics
openAIWSRetryMetrics
openaiWSRetryMetrics
openAIWSRetryMetrics
...
...
backend/internal/service/openai_ws_client.go
View file @
7076717b
...
@@ -11,6 +11,7 @@ import (
...
@@ -11,6 +11,7 @@ import (
"sync/atomic"
"sync/atomic"
"time"
"time"
openaiwsv2
"github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws
"github.com/coder/websocket"
coderws
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/coder/websocket/wsjson"
)
)
...
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
...
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
conn
*
coderws
.
Conn
conn
*
coderws
.
Conn
}
}
var
_
openaiwsv2
.
FrameConn
=
(
*
coderOpenAIWSClientConn
)(
nil
)
func
(
c
*
coderOpenAIWSClientConn
)
WriteJSON
(
ctx
context
.
Context
,
value
any
)
error
{
func
(
c
*
coderOpenAIWSClientConn
)
WriteJSON
(
ctx
context
.
Context
,
value
any
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
return
errOpenAIWSConnClosed
...
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
...
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
}
}
}
}
func
(
c
*
coderOpenAIWSClientConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
coderws
.
MessageText
,
nil
,
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
msgType
,
payload
,
err
:=
c
.
conn
.
Read
(
ctx
)
if
err
!=
nil
{
return
coderws
.
MessageText
,
nil
,
err
}
return
msgType
,
payload
,
nil
}
func
(
c
*
coderOpenAIWSClientConn
)
WriteFrame
(
ctx
context
.
Context
,
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
return
c
.
conn
.
Write
(
ctx
,
msgType
,
payload
)
}
func
(
c
*
coderOpenAIWSClientConn
)
Ping
(
ctx
context
.
Context
)
error
{
func
(
c
*
coderOpenAIWSClientConn
)
Ping
(
ctx
context
.
Context
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
return
errOpenAIWSConnClosed
...
...
backend/internal/service/openai_ws_forwarder.go
View file @
7076717b
...
@@ -46,9 +46,10 @@ const (
...
@@ -46,9 +46,10 @@ const (
openAIWSPayloadSizeEstimateMaxBytes
=
64
*
1024
openAIWSPayloadSizeEstimateMaxBytes
=
64
*
1024
openAIWSPayloadSizeEstimateMaxItems
=
16
openAIWSPayloadSizeEstimateMaxItems
=
16
openAIWSEventFlushBatchSizeDefault
=
4
openAIWSEventFlushBatchSizeDefault
=
4
openAIWSEventFlushIntervalDefault
=
25
*
time
.
Millisecond
openAIWSEventFlushIntervalDefault
=
25
*
time
.
Millisecond
openAIWSPayloadLogSampleDefault
=
0.2
openAIWSPayloadLogSampleDefault
=
0.2
openAIWSPassthroughIdleTimeoutDefault
=
time
.
Hour
openAIWSStoreDisabledConnModeStrict
=
"strict"
openAIWSStoreDisabledConnModeStrict
=
"strict"
openAIWSStoreDisabledConnModeAdaptive
=
"adaptive"
openAIWSStoreDisabledConnModeAdaptive
=
"adaptive"
...
@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
...
@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
return
s
.
openaiWSPool
return
s
.
openaiWSPool
}
}
func
(
s
*
OpenAIGatewayService
)
getOpenAIWSPassthroughDialer
()
openAIWSClientDialer
{
if
s
==
nil
{
return
nil
}
s
.
openaiWSPassthroughDialerOnce
.
Do
(
func
()
{
if
s
.
openaiWSPassthroughDialer
==
nil
{
s
.
openaiWSPassthroughDialer
=
newDefaultOpenAIWSClientDialer
()
}
})
return
s
.
openaiWSPassthroughDialer
}
func
(
s
*
OpenAIGatewayService
)
SnapshotOpenAIWSPoolMetrics
()
OpenAIWSPoolMetricsSnapshot
{
func
(
s
*
OpenAIGatewayService
)
SnapshotOpenAIWSPoolMetrics
()
OpenAIWSPoolMetricsSnapshot
{
pool
:=
s
.
getOpenAIWSConnPool
()
pool
:=
s
.
getOpenAIWSConnPool
()
if
pool
==
nil
{
if
pool
==
nil
{
...
@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
...
@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
return
15
*
time
.
Minute
return
15
*
time
.
Minute
}
}
func
(
s
*
OpenAIGatewayService
)
openAIWSPassthroughIdleTimeout
()
time
.
Duration
{
if
timeout
:=
s
.
openAIWSReadTimeout
();
timeout
>
0
{
return
timeout
}
return
openAIWSPassthroughIdleTimeoutDefault
}
func
(
s
*
OpenAIGatewayService
)
openAIWSWriteTimeout
()
time
.
Duration
{
func
(
s
*
OpenAIGatewayService
)
openAIWSWriteTimeout
()
time
.
Duration
{
if
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
>
0
{
if
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
>
0
{
return
time
.
Duration
(
s
.
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
)
*
time
.
Second
return
time
.
Duration
(
s
.
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
)
*
time
.
Second
...
@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
...
@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsDecision
:=
s
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
account
)
wsDecision
:=
s
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
account
)
modeRouterV2Enabled
:=
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
modeRouterV2Enabled
:=
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
ingressMode
:=
OpenAIWSIngressMode
Shared
ingressMode
:=
OpenAIWSIngressMode
CtxPool
if
modeRouterV2Enabled
{
if
modeRouterV2Enabled
{
ingressMode
=
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
s
.
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
)
ingressMode
=
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
s
.
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
)
if
ingressMode
==
OpenAIWSIngressModeOff
{
if
ingressMode
==
OpenAIWSIngressModeOff
{
...
@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
...
@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
nil
,
nil
,
)
)
}
}
switch
ingressMode
{
case
OpenAIWSIngressModePassthrough
:
if
wsDecision
.
Transport
!=
OpenAIUpstreamTransportResponsesWebsocketV2
{
return
fmt
.
Errorf
(
"websocket ingress requires ws_v2 transport, got=%s"
,
wsDecision
.
Transport
)
}
return
s
.
proxyResponsesWebSocketV2Passthrough
(
ctx
,
c
,
clientConn
,
account
,
token
,
firstClientMessage
,
hooks
,
wsDecision
,
)
case
OpenAIWSIngressModeCtxPool
,
OpenAIWSIngressModeShared
,
OpenAIWSIngressModeDedicated
:
// continue
default
:
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"websocket mode only supports ctx_pool/passthrough"
,
nil
,
)
}
}
}
if
wsDecision
.
Transport
!=
OpenAIUpstreamTransportResponsesWebsocketV2
{
if
wsDecision
.
Transport
!=
OpenAIUpstreamTransportResponsesWebsocketV2
{
return
fmt
.
Errorf
(
"websocket ingress requires ws_v2 transport, got=%s"
,
wsDecision
.
Transport
)
return
fmt
.
Errorf
(
"websocket ingress requires ws_v2 transport, got=%s"
,
wsDecision
.
Transport
)
...
...
backend/internal/service/openai_ws_forwarder_ingress_session_test.go
View file @
7076717b
...
@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
...
@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
require
.
True
(
t
,
<-
turnWSModeCh
,
"首轮 turn 应标记为 WS 模式"
)
require
.
True
(
t
,
<-
turnWSModeCh
,
"首轮 turn 应标记为 WS 模式"
)
require
.
True
(
t
,
<-
turnWSModeCh
,
"第二轮 turn 应标记为 WS 模式"
)
require
.
True
(
t
,
<-
turnWSModeCh
,
"第二轮 turn 应标记为 WS 模式"
)
require
.
NoError
(
t
,
clientConn
.
Close
(
coderws
.
StatusNormalClosure
,
"done"
)
)
_
=
clientConn
.
Close
(
coderws
.
StatusNormalClosure
,
"done"
)
select
{
select
{
case
serverErr
:=
<-
serverErrCh
:
case
serverErr
:=
<-
serverErrCh
:
...
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
...
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
require
.
Equal
(
t
,
2
,
dialer
.
DialCount
(),
"dedicated 模式下跨客户端会话不应复用上游连接"
)
require
.
Equal
(
t
,
2
,
dialer
.
DialCount
(),
"dedicated 模式下跨客户端会话不应复用上游连接"
)
}
}
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter
// drives ProxyResponsesWebSocketFromClient for an API-key account whose Extra
// sets "openai_apikey_responses_websockets_v2_mode" to passthrough, with a
// capture dialer injected as openaiWSPassthroughDialer.
//
// It asserts that the client receives the single upstream
// "response.completed" event, that the AfterTurn hook reports the matching
// request ID and usage, that the capture dialer was dialed exactly once, and
// that exactly one write reached the upstream capture connection.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// The global default mode is ctx_pool; the account below overrides it with
	// passthrough via its Extra map.
	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// Fake upstream connection preloaded with one completed-response event.
	upstreamConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
		},
	}
	captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
	svc := &OpenAIGatewayService{
		cfg:                       cfg,
		httpUpstream:              &httpUpstreamRecorder{},
		cache:                     &stubGatewayCache{},
		openaiWSResolver:          NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:             NewCodexToolCorrector(),
		openaiWSPassthroughDialer: captureDialer,
	}

	account := &Account{
		ID:          452,
		Name:        "openai-ingress-passthrough",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
		},
	}

	// Both channels are buffered with capacity 1 so one send from the handler
	// or the hook does not block on the test goroutine.
	serverErrCh := make(chan error, 1)
	resultCh := make(chan *OpenAIForwardResult, 1)
	hooks := &OpenAIWSIngressHooks{
		AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
			// Only successful turns with a result are forwarded to the test.
			if turnErr == nil && result != nil {
				resultCh <- result
			}
		},
	}

	// Local websocket server acting as the gateway ingress: it accepts the
	// client, reads the first message, then hands the connection to the service.
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()

		// Build a gin test context carrying a clone of the request with a fixed User-Agent.
		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req

		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}

		// The proxy call's return value is reported as the server-side outcome.
		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
	}))
	defer wsServer.Close()

	// Client side: dial the test server over ws:// (the http prefix is swapped for ws).
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()

	// Send the first response.create message.
	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
	cancelWrite()
	require.NoError(t, err)

	// The event read back must be the one preloaded on the upstream capture connection.
	readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
	_, event, readErr := clientConn.Read(readCtx)
	cancelRead()
	require.NoError(t, readErr)
	require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
	require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
	// The close error is deliberately ignored here.
	_ = clientConn.Close(coderws.StatusNormalClosure, "done")

	// The proxy call must return without error within 5 seconds of the client closing.
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 passthrough websocket 结束超时")
	}

	// The AfterTurn hook must have delivered a result carrying the upstream usage numbers.
	select {
	case result := <-resultCh:
		require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
		require.True(t, result.OpenAIWSMode)
		require.Equal(t, 2, result.Usage.InputTokens)
		require.Equal(t, 3, result.Usage.OutputTokens)
	case <-time.After(2 * time.Second):
		t.Fatal("未收到 passthrough turn 结果回调")
	}

	// Exactly one upstream dial and exactly one upstream write are expected.
	require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
	require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
}
func
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/service/openai_ws_forwarder_success_test.go
View file @
7076717b
...
@@ -15,6 +15,7 @@ import (
...
@@ -15,6 +15,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
...
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
return
event
,
nil
return
event
,
nil
}
}
func
(
c
*
openAIWSCaptureConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
payload
,
err
:=
c
.
ReadMessage
(
ctx
)
if
err
!=
nil
{
return
coderws
.
MessageText
,
nil
,
err
}
return
coderws
.
MessageText
,
payload
,
nil
}
func
(
c
*
openAIWSCaptureConn
)
WriteFrame
(
ctx
context
.
Context
,
_
coderws
.
MessageType
,
payload
[]
byte
)
error
{
return
c
.
WriteJSON
(
ctx
,
json
.
RawMessage
(
payload
))
}
func
(
c
*
openAIWSCaptureConn
)
Ping
(
ctx
context
.
Context
)
error
{
func
(
c
*
openAIWSCaptureConn
)
Ping
(
ctx
context
.
Context
)
error
{
_
=
ctx
_
=
ctx
return
nil
return
nil
...
...
backend/internal/service/openai_ws_protocol_resolver.go
View file @
7076717b
...
@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
...
@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
switch
mode
{
switch
mode
{
case
OpenAIWSIngressModeOff
:
case
OpenAIWSIngressModeOff
:
return
openAIWSHTTPDecision
(
"account_mode_off"
)
return
openAIWSHTTPDecision
(
"account_mode_off"
)
case
OpenAIWSIngressMode
Shared
,
OpenAIWSIngressMode
Dedicated
:
case
OpenAIWSIngressMode
CtxPool
,
OpenAIWSIngressMode
Passthrough
:
// continue
// continue
case
OpenAIWSIngressModeShared
,
OpenAIWSIngressModeDedicated
:
// 历史值兼容:按 ctx_pool 处理。
mode
=
OpenAIWSIngressModeCtxPool
default
:
default
:
return
openAIWSHTTPDecision
(
"account_mode_off"
)
return
openAIWSHTTPDecision
(
"account_mode_off"
)
}
}
...
...
backend/internal/service/openai_ws_protocol_resolver_test.go
View file @
7076717b
...
@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
...
@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
OpenAIWSIngressMode
Shared
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
OpenAIWSIngressMode
CtxPool
account
:=
&
Account
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Dedicated
,
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
CtxPool
,
},
},
}
}
t
.
Run
(
"
dedicated
mode routes to ws v2"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"
ctx_pool
mode routes to ws v2"
,
func
(
t
*
testing
.
T
)
{
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
account
)
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_
dedicated
"
,
decision
.
Reason
)
require
.
Equal
(
t
,
"ws_v2_mode_
ctx_pool
"
,
decision
.
Reason
)
})
})
t
.
Run
(
"off mode routes to http"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"off mode routes to http"
,
func
(
t
*
testing
.
T
)
{
...
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
...
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
require
.
Equal
(
t
,
"account_mode_off"
,
decision
.
Reason
)
require
.
Equal
(
t
,
"account_mode_off"
,
decision
.
Reason
)
})
})
t
.
Run
(
"legacy boolean maps to
shared
in v2 router"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"legacy boolean maps to
ctx_pool
in v2 router"
,
func
(
t
*
testing
.
T
)
{
legacyAccount
:=
&
Account
{
legacyAccount
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Type
:
AccountTypeAPIKey
,
...
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
...
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
}
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
legacyAccount
)
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
legacyAccount
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_shared"
,
decision
.
Reason
)
require
.
Equal
(
t
,
"ws_v2_mode_ctx_pool"
,
decision
.
Reason
)
})
t
.
Run
(
"passthrough mode routes to ws v2"
,
func
(
t
*
testing
.
T
)
{
passthroughAccount
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModePassthrough
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
passthroughAccount
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_passthrough"
,
decision
.
Reason
)
})
})
t
.
Run
(
"non-positive concurrency is rejected in v2 router"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"non-positive concurrency is rejected in v2 router"
,
func
(
t
*
testing
.
T
)
{
...
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
...
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
Shared
,
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressMode
CtxPool
,
},
},
}
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
invalidConcurrency
)
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
invalidConcurrency
)
...
...
backend/internal/service/openai_ws_v2/caddy_adapter.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"context"
)
// runCaddyStyleRelay follows the bidirectional-tunnel idea used by Caddy's
// reverseproxy: once the connection is established, both directions are copied
// concurrently, and the exit of either direction triggers a converging shutdown.
// It is a direct delegation to Relay with the same arguments and results.
//
// Reference:
// - Project: caddyserver/caddy (Apache-2.0)
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
// - Files:
//   - modules/caddyhttp/reverseproxy/streaming.go
//   - modules/caddyhttp/reverseproxy/reverseproxy.go
func runCaddyStyleRelay(
	ctx context.Context,
	clientConn FrameConn,
	upstreamConn FrameConn,
	firstClientMessage []byte,
	options RelayOptions,
) (RelayResult, *RelayExit) {
	return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
}
backend/internal/service/openai_ws_v2/entry.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
"context"
// EntryInput bundles the arguments for the passthrough v2 data-plane entry
// point (RunEntry). Each field is forwarded unchanged to the relay.
type EntryInput struct {
	// Ctx is the parent context of the relay.
	Ctx context.Context
	// ClientConn is the downstream (client-facing) frame connection.
	ClientConn FrameConn
	// UpstreamConn is the upstream frame connection.
	UpstreamConn FrameConn
	// FirstClientMessage is the first client payload, sent upstream before
	// the copy loops start.
	FirstClientMessage []byte
	// Options carries timeouts and observer callbacks for the relay.
	Options RelayOptions
}
// RunEntry is the single public entry point of the openai_ws_v2 package.
// It unpacks input and delegates to runCaddyStyleRelay, returning its result
// and exit descriptor unchanged.
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
	return runCaddyStyleRelay(
		input.Ctx,
		input.ClientConn,
		input.UpstreamConn,
		input.FirstClientMessage,
		input.Options,
	)
}
backend/internal/service/openai_ws_v2/metrics.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"sync/atomic"
)
// MetricsSnapshot is a lightweight runtime metrics snapshot for the OpenAI WS
// v2 passthrough path.
type MetricsSnapshot struct {
	// SemanticMutationTotal mirrors the passthroughSemanticMutationTotal counter.
	SemanticMutationTotal int64 `json:"semantic_mutation_total"`
	// UsageParseFailureTotal mirrors the passthroughUsageParseFailureTotal counter.
	UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
}
// Package-level counters backing MetricsSnapshot.
var (
	// The passthrough path performs no semantic rewriting by default, so this
	// counter should normally stay at 0 (kept for future defensive checks).
	passthroughSemanticMutationTotal atomic.Int64
	// Incremented by recordUsageParseFailure.
	passthroughUsageParseFailureTotal atomic.Int64
)
// recordUsageParseFailure increments the usage-parse-failure counter by one.
func recordUsageParseFailure() {
	passthroughUsageParseFailureTotal.Add(1)
}
// SnapshotMetrics 返回当前 passthrough 指标快照。
func
SnapshotMetrics
()
MetricsSnapshot
{
return
MetricsSnapshot
{
SemanticMutationTotal
:
passthroughSemanticMutationTotal
.
Load
(),
UsageParseFailureTotal
:
passthroughUsageParseFailureTotal
.
Load
(),
}
}
backend/internal/service/openai_ws_v2/passthrough_relay.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"context"
"errors"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
coderws
"github.com/coder/websocket"
"github.com/tidwall/gjson"
)
// FrameConn is the frame-oriented connection abstraction the relay works with.
// Both the client side and the upstream side are passed to Relay as FrameConn.
type FrameConn interface {
	// ReadFrame reads the next frame and returns its message type and payload.
	ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
	// WriteFrame writes one frame of the given message type.
	WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
	// Close closes the connection.
	Close() error
}
// Usage holds token counts extracted from upstream events.
type Usage struct {
	// InputTokens is read from response.usage.input_tokens.
	InputTokens int
	// OutputTokens is read from response.usage.output_tokens.
	OutputTokens int
	// CacheCreationInputTokens is not populated by parseUsageAndAccumulate.
	CacheCreationInputTokens int
	// CacheReadInputTokens is read from
	// response.usage.input_tokens_details.cached_tokens.
	CacheReadInputTokens int
}
// RelayResult summarizes a whole relay session as returned by Relay.
type RelayResult struct {
	// RequestModel is the trimmed "model" field of the first client message.
	RequestModel string
	// Usage is the usage accumulated over all parsed upstream events.
	Usage Usage
	// RequestID is the response id of the last observed terminal event.
	RequestID string
	// TerminalEventType is the type of the last observed terminal event.
	TerminalEventType string
	// FirstTokenMs is the delay in milliseconds from relay start to the first
	// token event; nil when no token event was observed.
	FirstTokenMs *int
	// Duration is the elapsed time measured from relay start.
	Duration time.Duration
	// ClientToUpstreamFrames counts frames forwarded upstream, including the
	// first client message.
	ClientToUpstreamFrames int64
	// UpstreamToClientFrames counts frames successfully written to the client.
	UpstreamToClientFrames int64
	// DroppedDownstreamFrames counts upstream frames that were read but not
	// forwarded because downstream writes were disabled.
	DroppedDownstreamFrames int64
}
// RelayTurnResult describes a single completed turn. It is passed to
// RelayOptions.OnTurnComplete when a terminal event with a response id is seen.
type RelayTurnResult struct {
	// RequestModel is the model taken from the first client message.
	RequestModel string
	// Usage is the usage parsed from this terminal event only.
	Usage Usage
	// RequestID is the response id of the turn.
	RequestID string
	// TerminalEventType is the event type that ended the turn.
	TerminalEventType string
	// Duration is the time from the first event seen for this response id to
	// its terminal event.
	Duration time.Duration
	// FirstTokenMs is the per-turn first-token delay in milliseconds, or nil.
	FirstTokenMs *int
}
// RelayExit describes a non-clean end of a relay. Relay returns nil instead
// when the session completed gracefully.
type RelayExit struct {
	// Stage names where the relay stopped (for example "relay_init",
	// "write_upstream", "read_upstream", "client_disconnected").
	Stage string
	// Err is the error associated with the exit.
	Err error
	// WroteDownstream reports whether any frame had been written to the client.
	WroteDownstream bool
}
// RelayOptions configures Relay. Zero values select the defaults noted below.
type RelayOptions struct {
	// WriteTimeout bounds each frame write; values <= 0 fall back to 2 minutes.
	WriteTimeout time.Duration
	// IdleTimeout enables the idle watchdog; values <= 0 disable it.
	IdleTimeout time.Duration
	// UpstreamDrainTimeout is how long Relay keeps reading upstream after a
	// graceful client disconnect; values <= 0 fall back to 1200ms.
	UpstreamDrainTimeout time.Duration
	// FirstMessageType is the frame type of the first client message; anything
	// other than MessageBinary is treated as MessageText.
	FirstMessageType coderws.MessageType
	// OnUsageParseFailure, if set, is called with the event type and the raw
	// usage JSON when usage cannot be parsed.
	OnUsageParseFailure func(eventType string, usageRaw string)
	// OnTurnComplete, if set, is called for each terminal upstream event that
	// carries a response id.
	OnTurnComplete func(turn RelayTurnResult)
	// OnTrace, if set, receives relay trace events.
	OnTrace func(event RelayTraceEvent)
	// Now is the clock used for timing; nil falls back to time.Now.
	Now func() time.Time
}
// RelayTraceEvent is a diagnostic event delivered to RelayOptions.OnTrace.
// Fields that do not apply to a given stage are left at their zero value.
type RelayTraceEvent struct {
	// Stage is the trace point name (for example "relay_start", "first_exit").
	Stage string
	// Direction is "client_to_upstream", "upstream_to_client", "watchdog", or "".
	Direction string
	// MessageType is the textual frame type produced by relayMessageTypeString.
	MessageType string
	// PayloadBytes is the payload length of the frame involved, if any.
	PayloadBytes int
	// Graceful reports whether the exit or error was classified as a normal
	// disconnect.
	Graceful bool
	// WroteDownstream reports whether any frame had reached the client so far.
	WroteDownstream bool
	// Error is the error text, or "" when there is none.
	Error string
}
// relayState is the mutable observation state of one relay session. It is
// updated while upstream text frames are inspected and copied into RelayResult
// by enrichResult.
type relayState struct {
	// usage is the running total over all parsed usage blocks.
	usage Usage
	// requestModel is the model taken from the first client message.
	requestModel string
	// lastResponseID is the response id of the most recent terminal event.
	lastResponseID string
	// terminalEventType is the type of the most recent terminal event.
	terminalEventType string
	// firstTokenMs is the session-level first-token delay in milliseconds.
	firstTokenMs *int
	// turnTimingByID tracks per-response timing until the terminal event.
	turnTimingByID map[string]*relayTurnTiming
}
// relayExitSignal is what a relay goroutine sends on the exit channel when it
// stops.
type relayExitSignal struct {
	// stage names the operation that ended the goroutine.
	stage string
	// err is the error that caused the exit, if any.
	err error
	// graceful marks the exit as a normal disconnect or completion.
	graceful bool
	// wroteDownstream reports whether a frame had been written to the client.
	wroteDownstream bool
}
// observedUpstreamEvent is the result of inspecting one upstream text frame.
type observedUpstreamEvent struct {
	// terminal is true when eventType is one of the terminal event types.
	terminal bool
	// eventType is the trimmed "type" field of the message.
	eventType string
	// responseID is the response id associated with the event, if found.
	responseID string
	// usage is the usage parsed from this event only.
	usage Usage
	// duration is the turn duration, set on terminal events with known timing.
	duration time.Duration
	// firstToken is the per-turn first-token delay in milliseconds, or nil.
	firstToken *int
}
// relayTurnTiming records timing for one response id.
type relayTurnTiming struct {
	// startAt is when the first event for the response id was observed.
	startAt time.Time
	// firstTokenMs is the delay from startAt to the first token event, or nil.
	firstTokenMs *int
}
// Relay forwards frames between clientConn and upstreamConn until one side
// stops, while observing upstream text frames for usage and timing.
//
// Flow:
//  1. firstClientMessage is written upstream (its "model" field becomes
//     RelayResult.RequestModel). A failure here returns stage "write_upstream".
//  2. Three goroutines are started: client->upstream copy, upstream->client
//     copy, and the idle watchdog. Each reports its exit on a shared channel.
//  3. After the first exit, a graceful client disconnect switches the upstream
//     copy into "drop" mode and waits up to the drain timeout for a second
//     exit; any other first exit cancels the relay context, closes upstream,
//     and waits 200ms for a second exit.
//  4. The relay context is cancelled, upstream is closed, and the result is
//     filled from the observation state and frame counters.
//
// The returned *RelayExit is nil only when the exits seen were graceful and the
// first exit was not a client disconnect; only on that path is clientConn
// closed here. A nil connection yields stage "relay_init". A nil ctx is
// replaced by context.Background().
//
// NOTE(review): state is read by enrichResult without waiting for the
// upstream->client goroutine to finish when the second exit came from another
// goroutine or the wait timed out. This looks like a possible unsynchronised
// read; confirm with -race.
func Relay(
	ctx context.Context,
	clientConn FrameConn,
	upstreamConn FrameConn,
	firstClientMessage []byte,
	options RelayOptions,
) (RelayResult, *RelayExit) {
	result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
	if clientConn == nil || upstreamConn == nil {
		return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
	}
	if ctx == nil {
		ctx = context.Background()
	}

	// Resolve option defaults.
	nowFn := options.Now
	if nowFn == nil {
		nowFn = time.Now
	}
	writeTimeout := options.WriteTimeout
	if writeTimeout <= 0 {
		writeTimeout = 2 * time.Minute
	}
	drainTimeout := options.UpstreamDrainTimeout
	if drainTimeout <= 0 {
		drainTimeout = 1200 * time.Millisecond
	}
	// Anything that is not explicitly binary is sent as text.
	firstMessageType := options.FirstMessageType
	if firstMessageType != coderws.MessageBinary {
		firstMessageType = coderws.MessageText
	}
	startAt := nowFn()
	state := &relayState{requestModel: result.RequestModel}
	onTrace := options.OnTrace

	relayCtx, relayCancel := context.WithCancel(ctx)
	defer relayCancel()

	// lastActivity holds the UnixNano timestamp of the latest frame activity,
	// read by the idle watchdog.
	lastActivity := atomic.Int64{}
	lastActivity.Store(nowFn().UnixNano())
	markActivity := func() {
		lastActivity.Store(nowFn().UnixNano())
	}

	// Each write gets its own deadline derived from the relay context.
	writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
		writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
		defer cancel()
		return upstreamConn.WriteFrame(writeCtx, msgType, payload)
	}
	writeClient := func(msgType coderws.MessageType, payload []byte) error {
		writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
		defer cancel()
		return clientConn.WriteFrame(writeCtx, msgType, payload)
	}
	clientToUpstreamFrames := &atomic.Int64{}
	upstreamToClientFrames := &atomic.Int64{}
	droppedDownstreamFrames := &atomic.Int64{}

	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:        "relay_start",
		PayloadBytes: len(firstClientMessage),
		MessageType:  relayMessageTypeString(firstMessageType),
	})

	// The first client message is forwarded synchronously before the copy
	// loops start.
	if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
		result.Duration = nowFn().Sub(startAt)
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:        "write_first_message_failed",
			Direction:    "client_to_upstream",
			MessageType:  relayMessageTypeString(firstMessageType),
			PayloadBytes: len(firstClientMessage),
			Error:        err.Error(),
		})
		return result, &RelayExit{Stage: "write_upstream", Err: err}
	}
	clientToUpstreamFrames.Add(1)
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:        "write_first_message_ok",
		Direction:    "client_to_upstream",
		MessageType:  relayMessageTypeString(firstMessageType),
		PayloadBytes: len(firstClientMessage),
	})
	markActivity()

	// Buffer of 3: one slot per goroutine, so no sender blocks on exit.
	exitCh := make(chan relayExitSignal, 3)
	dropDownstreamWrites := atomic.Bool{}
	go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
	go runUpstreamToClient(
		relayCtx,
		upstreamConn,
		writeClient,
		startAt,
		nowFn,
		state,
		options.OnUsageParseFailure,
		options.OnTurnComplete,
		&dropDownstreamWrites,
		upstreamToClientFrames,
		droppedDownstreamFrames,
		markActivity,
		onTrace,
		exitCh,
	)
	go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)

	firstExit := <-exitCh
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:           "first_exit",
		Direction:       relayDirectionFromStage(firstExit.stage),
		Graceful:        firstExit.graceful,
		WroteDownstream: firstExit.wroteDownstream,
		Error:           relayErrorString(firstExit.err),
	})

	combinedWroteDownstream := firstExit.wroteDownstream
	secondExit := relayExitSignal{graceful: true}
	hasSecondExit := false

	// After the client disconnects, keep reading upstream for a short window on
	// a best-effort basis to capture late usage/terminal events for billing.
	if firstExit.stage == "read_client" && firstExit.graceful {
		dropDownstreamWrites.Store(true)
		secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
	} else {
		relayCancel()
		_ = upstreamConn.Close()
		secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
	}
	if hasSecondExit {
		combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "second_exit",
			Direction:       relayDirectionFromStage(secondExit.stage),
			Graceful:        secondExit.graceful,
			WroteDownstream: secondExit.wroteDownstream,
			Error:           relayErrorString(secondExit.err),
		})
	}

	// Unconditional teardown of the relay context and the upstream side.
	relayCancel()
	_ = upstreamConn.Close()

	enrichResult(&result, state, nowFn().Sub(startAt))
	result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
	result.UpstreamToClientFrames = upstreamToClientFrames.Load()
	result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()

	// A client disconnect always yields a non-nil RelayExit; a non-graceful
	// second exit overrides the reported stage and error.
	if firstExit.stage == "read_client" && firstExit.graceful {
		stage := "client_disconnected"
		exitErr := firstExit.err
		if hasSecondExit && !secondExit.graceful {
			stage = secondExit.stage
			exitErr = secondExit.err
		}
		if exitErr == nil {
			exitErr = io.EOF
		}
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_exit",
			Direction:       relayDirectionFromStage(stage),
			Graceful:        false,
			WroteDownstream: combinedWroteDownstream,
			Error:           relayErrorString(exitErr),
		})
		return result, &RelayExit{
			Stage:           stage,
			Err:             exitErr,
			WroteDownstream: combinedWroteDownstream,
		}
	}
	// Both observed exits were graceful (or there was no second exit).
	if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_complete",
			Graceful:        true,
			WroteDownstream: combinedWroteDownstream,
		})
		_ = clientConn.Close()
		return result, nil
	}
	// The first exit was an error: report it.
	if !firstExit.graceful {
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_exit",
			Direction:       relayDirectionFromStage(firstExit.stage),
			Graceful:        false,
			WroteDownstream: combinedWroteDownstream,
			Error:           relayErrorString(firstExit.err),
		})
		return result, &RelayExit{
			Stage:           firstExit.stage,
			Err:             firstExit.err,
			WroteDownstream: combinedWroteDownstream,
		}
	}
	// The first exit was graceful but the second was an error: report the second.
	if hasSecondExit && !secondExit.graceful {
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_exit",
			Direction:       relayDirectionFromStage(secondExit.stage),
			Graceful:        false,
			WroteDownstream: combinedWroteDownstream,
			Error:           relayErrorString(secondExit.err),
		})
		return result, &RelayExit{
			Stage:           secondExit.stage,
			Err:             secondExit.err,
			WroteDownstream: combinedWroteDownstream,
		}
	}
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:           "relay_complete",
		Graceful:        true,
		WroteDownstream: combinedWroteDownstream,
	})
	_ = clientConn.Close()
	return result, nil
}
func
runClientToUpstream
(
ctx
context
.
Context
,
clientConn
FrameConn
,
writeUpstream
func
(
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
,
markActivity
func
(),
forwardedFrames
*
atomic
.
Int64
,
onTrace
func
(
event
RelayTraceEvent
),
exitCh
chan
<-
relayExitSignal
,
)
{
for
{
msgType
,
payload
,
err
:=
clientConn
.
ReadFrame
(
ctx
)
if
err
!=
nil
{
emitRelayTrace
(
onTrace
,
RelayTraceEvent
{
Stage
:
"read_client_failed"
,
Direction
:
"client_to_upstream"
,
Error
:
err
.
Error
(),
Graceful
:
isDisconnectError
(
err
),
})
exitCh
<-
relayExitSignal
{
stage
:
"read_client"
,
err
:
err
,
graceful
:
isDisconnectError
(
err
)}
return
}
markActivity
()
if
err
:=
writeUpstream
(
msgType
,
payload
);
err
!=
nil
{
emitRelayTrace
(
onTrace
,
RelayTraceEvent
{
Stage
:
"write_upstream_failed"
,
Direction
:
"client_to_upstream"
,
MessageType
:
relayMessageTypeString
(
msgType
),
PayloadBytes
:
len
(
payload
),
Error
:
err
.
Error
(),
})
exitCh
<-
relayExitSignal
{
stage
:
"write_upstream"
,
err
:
err
}
return
}
if
forwardedFrames
!=
nil
{
forwardedFrames
.
Add
(
1
)
}
markActivity
()
}
}
func
runUpstreamToClient
(
ctx
context
.
Context
,
upstreamConn
FrameConn
,
writeClient
func
(
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
,
startAt
time
.
Time
,
nowFn
func
()
time
.
Time
,
state
*
relayState
,
onUsageParseFailure
func
(
eventType
string
,
usageRaw
string
),
onTurnComplete
func
(
turn
RelayTurnResult
),
dropDownstreamWrites
*
atomic
.
Bool
,
forwardedFrames
*
atomic
.
Int64
,
droppedFrames
*
atomic
.
Int64
,
markActivity
func
(),
onTrace
func
(
event
RelayTraceEvent
),
exitCh
chan
<-
relayExitSignal
,
)
{
wroteDownstream
:=
false
for
{
msgType
,
payload
,
err
:=
upstreamConn
.
ReadFrame
(
ctx
)
if
err
!=
nil
{
emitRelayTrace
(
onTrace
,
RelayTraceEvent
{
Stage
:
"read_upstream_failed"
,
Direction
:
"upstream_to_client"
,
Error
:
err
.
Error
(),
Graceful
:
isDisconnectError
(
err
),
WroteDownstream
:
wroteDownstream
,
})
exitCh
<-
relayExitSignal
{
stage
:
"read_upstream"
,
err
:
err
,
graceful
:
isDisconnectError
(
err
),
wroteDownstream
:
wroteDownstream
,
}
return
}
markActivity
()
observedEvent
:=
observedUpstreamEvent
{}
switch
msgType
{
case
coderws
.
MessageText
:
observedEvent
=
observeUpstreamMessage
(
state
,
payload
,
startAt
,
nowFn
,
onUsageParseFailure
)
case
coderws
.
MessageBinary
:
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
}
emitTurnComplete
(
onTurnComplete
,
state
,
observedEvent
)
if
dropDownstreamWrites
!=
nil
&&
dropDownstreamWrites
.
Load
()
{
if
droppedFrames
!=
nil
{
droppedFrames
.
Add
(
1
)
}
emitRelayTrace
(
onTrace
,
RelayTraceEvent
{
Stage
:
"drop_downstream_frame"
,
Direction
:
"upstream_to_client"
,
MessageType
:
relayMessageTypeString
(
msgType
),
PayloadBytes
:
len
(
payload
),
WroteDownstream
:
wroteDownstream
,
})
if
observedEvent
.
terminal
{
exitCh
<-
relayExitSignal
{
stage
:
"drain_terminal"
,
graceful
:
true
,
wroteDownstream
:
wroteDownstream
,
}
return
}
markActivity
()
continue
}
if
err
:=
writeClient
(
msgType
,
payload
);
err
!=
nil
{
emitRelayTrace
(
onTrace
,
RelayTraceEvent
{
Stage
:
"write_client_failed"
,
Direction
:
"upstream_to_client"
,
MessageType
:
relayMessageTypeString
(
msgType
),
PayloadBytes
:
len
(
payload
),
WroteDownstream
:
wroteDownstream
,
Error
:
err
.
Error
(),
})
exitCh
<-
relayExitSignal
{
stage
:
"write_client"
,
err
:
err
,
wroteDownstream
:
wroteDownstream
}
return
}
wroteDownstream
=
true
if
forwardedFrames
!=
nil
{
forwardedFrames
.
Add
(
1
)
}
markActivity
()
}
}
func
runIdleWatchdog
(
ctx
context
.
Context
,
nowFn
func
()
time
.
Time
,
idleTimeout
time
.
Duration
,
lastActivity
*
atomic
.
Int64
,
onTrace
func
(
event
RelayTraceEvent
),
exitCh
chan
<-
relayExitSignal
,
)
{
if
idleTimeout
<=
0
{
return
}
checkInterval
:=
minDuration
(
idleTimeout
/
4
,
5
*
time
.
Second
)
if
checkInterval
<
time
.
Second
{
checkInterval
=
time
.
Second
}
ticker
:=
time
.
NewTicker
(
checkInterval
)
defer
ticker
.
Stop
()
for
{
select
{
case
<-
ctx
.
Done
()
:
return
case
<-
ticker
.
C
:
last
:=
time
.
Unix
(
0
,
lastActivity
.
Load
())
if
nowFn
()
.
Sub
(
last
)
<
idleTimeout
{
continue
}
emitRelayTrace
(
onTrace
,
RelayTraceEvent
{
Stage
:
"idle_timeout_triggered"
,
Direction
:
"watchdog"
,
Error
:
context
.
DeadlineExceeded
.
Error
(),
})
exitCh
<-
relayExitSignal
{
stage
:
"idle_timeout"
,
err
:
context
.
DeadlineExceeded
}
return
}
}
}
func
emitRelayTrace
(
onTrace
func
(
event
RelayTraceEvent
),
event
RelayTraceEvent
)
{
if
onTrace
==
nil
{
return
}
onTrace
(
event
)
}
func
relayMessageTypeString
(
msgType
coderws
.
MessageType
)
string
{
switch
msgType
{
case
coderws
.
MessageText
:
return
"text"
case
coderws
.
MessageBinary
:
return
"binary"
default
:
return
"unknown("
+
strconv
.
Itoa
(
int
(
msgType
))
+
")"
}
}
// relayDirectionFromStage maps an exit stage to the trace direction it
// belongs to. Unknown stages yield "".
//
// "client_disconnected" is the stage Relay reports after a graceful client
// read exit. It is mapped to the client->upstream direction so the
// "relay_exit" trace carries the same direction as the "first_exit" trace
// emitted for the underlying "read_client" exit.
func relayDirectionFromStage(stage string) string {
	switch stage {
	case "read_client", "write_upstream", "client_disconnected":
		return "client_to_upstream"
	case "read_upstream", "write_client", "drain_terminal":
		return "upstream_to_client"
	case "idle_timeout":
		return "watchdog"
	default:
		return ""
	}
}
// relayErrorString returns err.Error(), or "" for a nil error, so trace events
// can carry an optional error as plain text.
func relayErrorString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
func
observeUpstreamMessage
(
state
*
relayState
,
message
[]
byte
,
startAt
time
.
Time
,
nowFn
func
()
time
.
Time
,
onUsageParseFailure
func
(
eventType
string
,
usageRaw
string
),
)
observedUpstreamEvent
{
if
state
==
nil
||
len
(
message
)
==
0
{
return
observedUpstreamEvent
{}
}
values
:=
gjson
.
GetManyBytes
(
message
,
"type"
,
"response.id"
,
"response_id"
,
"id"
)
eventType
:=
strings
.
TrimSpace
(
values
[
0
]
.
String
())
if
eventType
==
""
{
return
observedUpstreamEvent
{}
}
responseID
:=
strings
.
TrimSpace
(
values
[
1
]
.
String
())
if
responseID
==
""
{
responseID
=
strings
.
TrimSpace
(
values
[
2
]
.
String
())
}
// 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。
if
responseID
==
""
&&
isTerminalEvent
(
eventType
)
{
responseID
=
strings
.
TrimSpace
(
values
[
3
]
.
String
())
}
now
:=
nowFn
()
if
state
.
firstTokenMs
==
nil
&&
isTokenEvent
(
eventType
)
{
ms
:=
int
(
now
.
Sub
(
startAt
)
.
Milliseconds
())
if
ms
>=
0
{
state
.
firstTokenMs
=
&
ms
}
}
parsedUsage
:=
parseUsageAndAccumulate
(
state
,
message
,
eventType
,
onUsageParseFailure
)
observed
:=
observedUpstreamEvent
{
eventType
:
eventType
,
responseID
:
responseID
,
usage
:
parsedUsage
,
}
if
responseID
!=
""
{
turnTiming
:=
openAIWSRelayGetOrInitTurnTiming
(
state
,
responseID
,
now
)
if
turnTiming
!=
nil
&&
turnTiming
.
firstTokenMs
==
nil
&&
isTokenEvent
(
eventType
)
{
ms
:=
int
(
now
.
Sub
(
turnTiming
.
startAt
)
.
Milliseconds
())
if
ms
>=
0
{
turnTiming
.
firstTokenMs
=
&
ms
}
}
}
if
!
isTerminalEvent
(
eventType
)
{
return
observed
}
observed
.
terminal
=
true
state
.
terminalEventType
=
eventType
if
responseID
!=
""
{
state
.
lastResponseID
=
responseID
if
turnTiming
,
ok
:=
openAIWSRelayDeleteTurnTiming
(
state
,
responseID
);
ok
{
duration
:=
now
.
Sub
(
turnTiming
.
startAt
)
if
duration
<
0
{
duration
=
0
}
observed
.
duration
=
duration
observed
.
firstToken
=
openAIWSRelayCloneIntPtr
(
turnTiming
.
firstTokenMs
)
}
}
return
observed
}
func
emitTurnComplete
(
onTurnComplete
func
(
turn
RelayTurnResult
),
state
*
relayState
,
observed
observedUpstreamEvent
,
)
{
if
onTurnComplete
==
nil
||
!
observed
.
terminal
{
return
}
responseID
:=
strings
.
TrimSpace
(
observed
.
responseID
)
if
responseID
==
""
{
return
}
requestModel
:=
""
if
state
!=
nil
{
requestModel
=
state
.
requestModel
}
onTurnComplete
(
RelayTurnResult
{
RequestModel
:
requestModel
,
Usage
:
observed
.
usage
,
RequestID
:
responseID
,
TerminalEventType
:
observed
.
eventType
,
Duration
:
observed
.
duration
,
FirstTokenMs
:
openAIWSRelayCloneIntPtr
(
observed
.
firstToken
),
})
}
func
openAIWSRelayGetOrInitTurnTiming
(
state
*
relayState
,
responseID
string
,
now
time
.
Time
)
*
relayTurnTiming
{
if
state
==
nil
{
return
nil
}
if
state
.
turnTimingByID
==
nil
{
state
.
turnTimingByID
=
make
(
map
[
string
]
*
relayTurnTiming
,
8
)
}
timing
,
ok
:=
state
.
turnTimingByID
[
responseID
]
if
!
ok
||
timing
==
nil
||
timing
.
startAt
.
IsZero
()
{
timing
=
&
relayTurnTiming
{
startAt
:
now
}
state
.
turnTimingByID
[
responseID
]
=
timing
return
timing
}
return
timing
}
func
openAIWSRelayDeleteTurnTiming
(
state
*
relayState
,
responseID
string
)
(
relayTurnTiming
,
bool
)
{
if
state
==
nil
||
state
.
turnTimingByID
==
nil
{
return
relayTurnTiming
{},
false
}
timing
,
ok
:=
state
.
turnTimingByID
[
responseID
]
if
!
ok
||
timing
==
nil
{
return
relayTurnTiming
{},
false
}
delete
(
state
.
turnTimingByID
,
responseID
)
return
*
timing
,
true
}
// openAIWSRelayCloneIntPtr returns a pointer to a fresh copy of *v, or nil
// when v is nil, so the caller's value cannot alias internal state.
func openAIWSRelayCloneIntPtr(v *int) *int {
	if v != nil {
		copied := *v
		return &copied
	}
	return nil
}
func
parseUsageAndAccumulate
(
state
*
relayState
,
message
[]
byte
,
eventType
string
,
onParseFailure
func
(
eventType
string
,
usageRaw
string
),
)
Usage
{
if
state
==
nil
||
len
(
message
)
==
0
||
!
shouldParseUsage
(
eventType
)
{
return
Usage
{}
}
usageResult
:=
gjson
.
GetBytes
(
message
,
"response.usage"
)
if
!
usageResult
.
Exists
()
{
return
Usage
{}
}
usageRaw
:=
strings
.
TrimSpace
(
usageResult
.
Raw
)
if
usageRaw
==
""
||
!
strings
.
HasPrefix
(
usageRaw
,
"{"
)
{
recordUsageParseFailure
()
if
onParseFailure
!=
nil
{
onParseFailure
(
eventType
,
usageRaw
)
}
return
Usage
{}
}
inputResult
:=
gjson
.
GetBytes
(
message
,
"response.usage.input_tokens"
)
outputResult
:=
gjson
.
GetBytes
(
message
,
"response.usage.output_tokens"
)
cachedResult
:=
gjson
.
GetBytes
(
message
,
"response.usage.input_tokens_details.cached_tokens"
)
inputTokens
,
inputOK
:=
parseUsageIntField
(
inputResult
,
true
)
outputTokens
,
outputOK
:=
parseUsageIntField
(
outputResult
,
true
)
cachedTokens
,
cachedOK
:=
parseUsageIntField
(
cachedResult
,
false
)
if
!
inputOK
||
!
outputOK
||
!
cachedOK
{
recordUsageParseFailure
()
if
onParseFailure
!=
nil
{
onParseFailure
(
eventType
,
usageRaw
)
}
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
return
Usage
{}
}
parsedUsage
:=
Usage
{
InputTokens
:
inputTokens
,
OutputTokens
:
outputTokens
,
CacheReadInputTokens
:
cachedTokens
,
}
state
.
usage
.
InputTokens
+=
parsedUsage
.
InputTokens
state
.
usage
.
OutputTokens
+=
parsedUsage
.
OutputTokens
state
.
usage
.
CacheReadInputTokens
+=
parsedUsage
.
CacheReadInputTokens
return
parsedUsage
}
func
parseUsageIntField
(
value
gjson
.
Result
,
required
bool
)
(
int
,
bool
)
{
if
!
value
.
Exists
()
{
return
0
,
!
required
}
if
value
.
Type
!=
gjson
.
Number
{
return
0
,
false
}
return
int
(
value
.
Int
()),
true
}
func
enrichResult
(
result
*
RelayResult
,
state
*
relayState
,
duration
time
.
Duration
)
{
if
result
==
nil
{
return
}
result
.
Duration
=
duration
if
state
==
nil
{
return
}
result
.
RequestModel
=
state
.
requestModel
result
.
Usage
=
state
.
usage
result
.
RequestID
=
state
.
lastResponseID
result
.
TerminalEventType
=
state
.
terminalEventType
result
.
FirstTokenMs
=
state
.
firstTokenMs
}
func
isDisconnectError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
if
errors
.
Is
(
err
,
io
.
EOF
)
||
errors
.
Is
(
err
,
net
.
ErrClosed
)
||
errors
.
Is
(
err
,
context
.
Canceled
)
{
return
true
}
switch
coderws
.
CloseStatus
(
err
)
{
case
coderws
.
StatusNormalClosure
,
coderws
.
StatusGoingAway
,
coderws
.
StatusNoStatusRcvd
,
coderws
.
StatusAbnormalClosure
:
return
true
}
message
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
err
.
Error
()))
if
message
==
""
{
return
false
}
return
strings
.
Contains
(
message
,
"failed to read frame header: eof"
)
||
strings
.
Contains
(
message
,
"unexpected eof"
)
||
strings
.
Contains
(
message
,
"use of closed network connection"
)
||
strings
.
Contains
(
message
,
"connection reset by peer"
)
||
strings
.
Contains
(
message
,
"broken pipe"
)
}
// isTerminalEvent reports whether eventType ends a response: completed, done,
// failed, incomplete, or cancelled (either spelling).
func isTerminalEvent(eventType string) bool {
	switch eventType {
	case "response.completed", "response.done":
		return true
	case "response.failed", "response.incomplete":
		return true
	case "response.cancelled", "response.canceled":
		return true
	}
	return false
}
// shouldParseUsage reports whether usage is parsed for eventType. Only
// response.completed, response.done and response.failed qualify.
func shouldParseUsage(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed"
}
// isTokenEvent reports whether eventType counts as output for first-token
// timing. Lifecycle events (created, in_progress, output_item.added/done) never
// count. Any event containing ".delta" or starting with "response.output"
// counts, as do response.completed and response.done.
func isTokenEvent(eventType string) bool {
	switch eventType {
	case "":
		return false
	case "response.created",
		"response.in_progress",
		"response.output_item.added",
		"response.output_item.done":
		return false
	case "response.completed", "response.done":
		return true
	}
	// The "response.output" prefix also covers every "response.output_text*"
	// event, so a single prefix test is sufficient.
	return strings.Contains(eventType, ".delta") || strings.HasPrefix(eventType, "response.output")
}
// minDuration returns the smaller of a and b, treating a non-positive value as
// "unset": if a <= 0 it returns b, otherwise if b <= 0 it returns a.
func minDuration(a, b time.Duration) time.Duration {
	switch {
	case a <= 0:
		return b
	case b <= 0:
		return a
	case a < b:
		return a
	default:
		return b
	}
}
func
waitRelayExit
(
exitCh
<-
chan
relayExitSignal
,
timeout
time
.
Duration
)
(
relayExitSignal
,
bool
)
{
if
timeout
<=
0
{
timeout
=
200
*
time
.
Millisecond
}
select
{
case
sig
:=
<-
exitCh
:
return
sig
,
true
case
<-
time
.
After
(
timeout
)
:
return
relayExitSignal
{},
false
}
}
backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
coderws
"github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestRunEntry_DelegatesRelay checks that RunEntry drives a relay end to end:
// with an upstream that yields a single response.completed event, the call
// returns no RelayExit and reports that event's response id as RequestID.
//
// NOTE(review): newPassthroughTestFrameConn is defined outside this view; its
// boolean argument presumably controls how the conn behaves once its frames are
// exhausted. Confirm against the helper.
func TestRunEntry_DelegatesRelay(t *testing.T) {
	t.Parallel()

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}, true)

	result, relayExit := RunEntry(EntryInput{
		Ctx:                context.Background(),
		ClientConn:         clientConn,
		UpstreamConn:       upstreamConn,
		FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
	})
	require.Nil(t, relayExit)
	require.Equal(t, "resp_entry", result.RequestID)
}
// TestRunClientToUpstream_ErrorPaths exercises runClientToUpstream directly.
// Each subtest calls it synchronously with an exit channel of capacity 1, so
// the single exit signal is buffered and read afterwards.
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
	t.Parallel()

	// A client with no frames: the read fails and the exit is graceful.
	t.Run("read client eof", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		runClientToUpstream(
			context.Background(),
			newPassthroughTestFrameConn(nil, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			func() {},
			nil,
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "read_client", sig.stage)
		require.True(t, sig.graceful)
	})

	// The upstream writer fails on the first frame: non-graceful write exit.
	t.Run("write upstream failed", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		runClientToUpstream(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
			func() {},
			nil,
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "write_upstream", sig.stage)
		require.False(t, sig.graceful)
	})

	// One frame is forwarded successfully: the counter is incremented and the
	// trace callback sees at least the final read failure.
	t.Run("forwarded counter and trace callback", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		forwarded := &atomic.Int64{}
		traces := make([]RelayTraceEvent, 0, 2)
		runClientToUpstream(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			func() {},
			forwarded,
			func(event RelayTraceEvent) {
				traces = append(traces, event)
			},
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "read_client", sig.stage)
		require.Equal(t, int64(1), forwarded.Load())
		require.NotEmpty(t, traces)
	})
}
// TestRunUpstreamToClient_ErrorAndDropPaths exercises runUpstreamToClient
// directly: the upstream read failure, the client write failure, and the
// drop-mode path that stops on a terminal event. Each subtest calls the
// function synchronously with an exit channel of capacity 1.
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
	t.Parallel()

	// An upstream with no frames: the read fails and the exit is graceful.
	t.Run("read upstream eof", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		drop := &atomic.Bool{}
		drop.Store(false)
		runUpstreamToClient(
			context.Background(),
			newPassthroughTestFrameConn(nil, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			time.Now(),
			time.Now,
			&relayState{},
			nil,
			nil,
			drop,
			nil,
			nil,
			func() {},
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "read_upstream", sig.stage)
		require.True(t, sig.graceful)
	})

	// The client writer fails on the first forwarded frame.
	t.Run("write client failed", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		drop := &atomic.Bool{}
		drop.Store(false)
		runUpstreamToClient(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
			time.Now(),
			time.Now,
			&relayState{},
			nil,
			nil,
			drop,
			nil,
			nil,
			func() {},
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "write_client", sig.stage)
	})

	// With downstream writes disabled, a terminal event is counted as dropped
	// and ends the loop gracefully with stage "drain_terminal".
	t.Run("drop downstream and stop on terminal", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		drop := &atomic.Bool{}
		drop.Store(true)
		dropped := &atomic.Int64{}
		runUpstreamToClient(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{
					msgType: coderws.MessageText,
					payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
				},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			time.Now(),
			time.Now,
			&relayState{},
			nil,
			nil,
			drop,
			nil,
			dropped,
			func() {},
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "drain_terminal", sig.stage)
		require.True(t, sig.graceful)
		require.Equal(t, int64(1), dropped.Load())
	})
}
// TestRunIdleWatchdog_NoTimeoutWhenDisabled verifies that an idle timeout of 0
// disables the watchdog: no exit signal is produced within 200ms.
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
	t.Parallel()

	exitCh := make(chan relayExitSignal, 1)
	lastActivity := &atomic.Int64{}
	lastActivity.Store(time.Now().UnixNano())
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
	select {
	case <-exitCh:
		t.Fatal("unexpected idle timeout signal")
	case <-time.After(200 * time.Millisecond):
		// No signal within the window: the watchdog stayed silent as expected.
	}
}
func
TestHelperFunctionsCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
Equal
(
t
,
"text"
,
relayMessageTypeString
(
coderws
.
MessageText
))
require
.
Equal
(
t
,
"binary"
,
relayMessageTypeString
(
coderws
.
MessageBinary
))
require
.
Contains
(
t
,
relayMessageTypeString
(
coderws
.
MessageType
(
99
)),
"unknown("
)
require
.
Equal
(
t
,
""
,
relayErrorString
(
nil
))
require
.
Equal
(
t
,
"x"
,
relayErrorString
(
errors
.
New
(
"x"
)))
require
.
True
(
t
,
isDisconnectError
(
io
.
EOF
))
require
.
True
(
t
,
isDisconnectError
(
net
.
ErrClosed
))
require
.
True
(
t
,
isDisconnectError
(
context
.
Canceled
))
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusGoingAway
}))
require
.
True
(
t
,
isDisconnectError
(
errors
.
New
(
"broken pipe"
)))
require
.
False
(
t
,
isDisconnectError
(
errors
.
New
(
"unrelated"
)))
require
.
True
(
t
,
isTokenEvent
(
"response.output_text.delta"
))
require
.
True
(
t
,
isTokenEvent
(
"response.output_audio.delta"
))
require
.
True
(
t
,
isTokenEvent
(
"response.completed"
))
require
.
False
(
t
,
isTokenEvent
(
""
))
require
.
False
(
t
,
isTokenEvent
(
"response.created"
))
require
.
Equal
(
t
,
2
*
time
.
Second
,
minDuration
(
2
*
time
.
Second
,
5
*
time
.
Second
))
require
.
Equal
(
t
,
2
*
time
.
Second
,
minDuration
(
5
*
time
.
Second
,
2
*
time
.
Second
))
require
.
Equal
(
t
,
5
*
time
.
Second
,
minDuration
(
0
,
5
*
time
.
Second
))
require
.
Equal
(
t
,
2
*
time
.
Second
,
minDuration
(
2
*
time
.
Second
,
0
))
ch
:=
make
(
chan
relayExitSignal
,
1
)
ch
<-
relayExitSignal
{
stage
:
"ok"
}
sig
,
ok
:=
waitRelayExit
(
ch
,
10
*
time
.
Millisecond
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"ok"
,
sig
.
stage
)
ch
<-
relayExitSignal
{
stage
:
"ok2"
}
sig
,
ok
=
waitRelayExit
(
ch
,
0
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"ok2"
,
sig
.
stage
)
_
,
ok
=
waitRelayExit
(
ch
,
10
*
time
.
Millisecond
)
require
.
False
(
t
,
ok
)
n
,
ok
:=
parseUsageIntField
(
gjson
.
Get
(
`{"n":3}`
,
"n"
),
true
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
3
,
n
)
_
,
ok
=
parseUsageIntField
(
gjson
.
Get
(
`{"n":"x"}`
,
"n"
),
true
)
require
.
False
(
t
,
ok
)
n
,
ok
=
parseUsageIntField
(
gjson
.
Result
{},
false
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
0
,
n
)
_
,
ok
=
parseUsageIntField
(
gjson
.
Result
{},
true
)
require
.
False
(
t
,
ok
)
}
func
TestParseUsageAndEnrichCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
state
:=
&
relayState
{}
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`
),
"response.completed"
,
nil
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
InputTokens
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`
),
"response.completed"
,
nil
,
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
InputTokens
,
"部分字段解析失败时不应累加 usage"
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
CacheReadInputTokens
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`
),
"response.completed"
,
nil
,
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
InputTokens
,
"必填 usage 字段缺失时不应累加 usage"
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
0
,
state
.
usage
.
CacheReadInputTokens
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`
),
"response.completed"
,
nil
)
require
.
Equal
(
t
,
2
,
state
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
1
,
state
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
1
,
state
.
usage
.
CacheReadInputTokens
)
result
:=
&
RelayResult
{}
enrichResult
(
result
,
state
,
5
*
time
.
Millisecond
)
require
.
Equal
(
t
,
state
.
usage
.
InputTokens
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
5
*
time
.
Millisecond
,
result
.
Duration
)
parseUsageAndAccumulate
(
state
,
[]
byte
(
`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`
),
"response.in_progress"
,
nil
)
require
.
Equal
(
t
,
2
,
state
.
usage
.
InputTokens
)
enrichResult
(
nil
,
state
,
0
)
}
func
TestEmitTurnCompleteCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 非 terminal 事件不应触发。
called
:=
0
emitTurnComplete
(
func
(
turn
RelayTurnResult
)
{
called
++
},
&
relayState
{
requestModel
:
"gpt-5"
},
observedUpstreamEvent
{
terminal
:
false
,
eventType
:
"response.output_text.delta"
,
responseID
:
"resp_ignored"
,
usage
:
Usage
{
InputTokens
:
1
},
})
require
.
Equal
(
t
,
0
,
called
)
// 缺少 response_id 时不应触发。
emitTurnComplete
(
func
(
turn
RelayTurnResult
)
{
called
++
},
&
relayState
{
requestModel
:
"gpt-5"
},
observedUpstreamEvent
{
terminal
:
true
,
eventType
:
"response.completed"
,
})
require
.
Equal
(
t
,
0
,
called
)
// terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。
var
got
RelayTurnResult
emitTurnComplete
(
func
(
turn
RelayTurnResult
)
{
called
++
got
=
turn
},
nil
,
observedUpstreamEvent
{
terminal
:
true
,
eventType
:
"response.completed"
,
responseID
:
"resp_emit"
,
usage
:
Usage
{
InputTokens
:
2
,
OutputTokens
:
3
},
})
require
.
Equal
(
t
,
1
,
called
)
require
.
Equal
(
t
,
"resp_emit"
,
got
.
RequestID
)
require
.
Equal
(
t
,
"response.completed"
,
got
.
TerminalEventType
)
require
.
Equal
(
t
,
2
,
got
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
3
,
got
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
""
,
got
.
RequestModel
)
}
func
TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusNormalClosure
}))
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusNoStatusRcvd
}))
require
.
True
(
t
,
isDisconnectError
(
coderws
.
CloseError
{
Code
:
coderws
.
StatusAbnormalClosure
}))
require
.
True
(
t
,
isDisconnectError
(
errors
.
New
(
"connection reset by peer"
)))
require
.
False
(
t
,
isDisconnectError
(
errors
.
New
(
" "
)))
}
func
TestIsTokenEventCoverageBranches
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
False
(
t
,
isTokenEvent
(
"response.in_progress"
))
require
.
False
(
t
,
isTokenEvent
(
"response.output_item.added"
))
require
.
True
(
t
,
isTokenEvent
(
"response.output_audio.delta"
))
require
.
True
(
t
,
isTokenEvent
(
"response.output"
))
require
.
True
(
t
,
isTokenEvent
(
"response.done"
))
}
func
TestRelayTurnTimingHelpersCoverage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
now
:=
time
.
Unix
(
100
,
0
)
// nil state
require
.
Nil
(
t
,
openAIWSRelayGetOrInitTurnTiming
(
nil
,
"resp_nil"
,
now
))
_
,
ok
:=
openAIWSRelayDeleteTurnTiming
(
nil
,
"resp_nil"
)
require
.
False
(
t
,
ok
)
state
:=
&
relayState
{}
timing
:=
openAIWSRelayGetOrInitTurnTiming
(
state
,
"resp_a"
,
now
)
require
.
NotNil
(
t
,
timing
)
require
.
Equal
(
t
,
now
,
timing
.
startAt
)
// 再次获取返回同一条 timing
timing2
:=
openAIWSRelayGetOrInitTurnTiming
(
state
,
"resp_a"
,
now
.
Add
(
5
*
time
.
Second
))
require
.
NotNil
(
t
,
timing2
)
require
.
Equal
(
t
,
now
,
timing2
.
startAt
)
// 删除存在键
deleted
,
ok
:=
openAIWSRelayDeleteTurnTiming
(
state
,
"resp_a"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
now
,
deleted
.
startAt
)
// 删除不存在键
_
,
ok
=
openAIWSRelayDeleteTurnTiming
(
state
,
"resp_a"
)
require
.
False
(
t
,
ok
)
}
func
TestObserveUpstreamMessage_ResponseIDFallbackPolicy
(
t
*
testing
.
T
)
{
t
.
Parallel
()
state
:=
&
relayState
{
requestModel
:
"gpt-5"
}
startAt
:=
time
.
Unix
(
0
,
0
)
now
:=
startAt
nowFn
:=
func
()
time
.
Time
{
now
=
now
.
Add
(
5
*
time
.
Millisecond
)
return
now
}
// 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。
observed
:=
observeUpstreamMessage
(
state
,
[]
byte
(
`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`
),
startAt
,
nowFn
,
nil
,
)
require
.
False
(
t
,
observed
.
terminal
)
require
.
Equal
(
t
,
""
,
observed
.
responseID
)
// terminal:允许兜底用顶层 id(用于兼容少数字段变体)。
observed
=
observeUpstreamMessage
(
state
,
[]
byte
(
`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`
),
startAt
,
nowFn
,
nil
,
)
require
.
True
(
t
,
observed
.
terminal
)
require
.
Equal
(
t
,
"resp_fallback"
,
observed
.
responseID
)
}
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
0 → 100644
View file @
7076717b
package
openai_ws_v2
import
(
"context"
"errors"
"io"
"sync"
"sync/atomic"
"testing"
"time"
coderws
"github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
// passthroughTestFrame is a single WebSocket frame as used by the test
// connections in this file: a message type plus the raw payload bytes.
type passthroughTestFrame struct {
	msgType coderws.MessageType // frame message type (text or binary)
	payload []byte              // raw frame bytes
}
// passthroughTestFrameConn is an in-memory connection for tests. ReadFrame
// hands out frames queued on readCh, and WriteFrame records every written
// frame so tests can inspect them through Writes.
type passthroughTestFrameConn struct {
	mu     sync.Mutex                // guards writes
	writes []passthroughTestFrame    // frames recorded by WriteFrame
	readCh chan passthroughTestFrame // frames served by ReadFrame; once closed, reads return io.EOF
	once   sync.Once                 // makes Close attempt to close readCh at most once
}
// delayedReadFrameConn wraps another FrameConn and delays only the first
// ReadFrame call by firstDelay; every other operation is forwarded to base.
type delayedReadFrameConn struct {
	base       FrameConn     // wrapped connection
	firstDelay time.Duration // delay applied before the first read; <= 0 means no delay
	once       sync.Once     // ensures the delay is applied only once
}
// closeSpyFrameConn is a test connection that counts how many times Close is
// called. Its ReadFrame blocks until the context is done.
type closeSpyFrameConn struct {
	closeCalls atomic.Int32 // number of Close invocations
}
func
newPassthroughTestFrameConn
(
frames
[]
passthroughTestFrame
,
autoClose
bool
)
*
passthroughTestFrameConn
{
c
:=
&
passthroughTestFrameConn
{
readCh
:
make
(
chan
passthroughTestFrame
,
len
(
frames
)
+
1
),
}
for
_
,
frame
:=
range
frames
{
copied
:=
passthroughTestFrame
{
msgType
:
frame
.
msgType
,
payload
:
append
([]
byte
(
nil
),
frame
.
payload
...
)}
c
.
readCh
<-
copied
}
if
autoClose
{
close
(
c
.
readCh
)
}
return
c
}
func
(
c
*
passthroughTestFrameConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
select
{
case
<-
ctx
.
Done
()
:
return
coderws
.
MessageText
,
nil
,
ctx
.
Err
()
case
frame
,
ok
:=
<-
c
.
readCh
:
if
!
ok
{
return
coderws
.
MessageText
,
nil
,
io
.
EOF
}
return
frame
.
msgType
,
append
([]
byte
(
nil
),
frame
.
payload
...
),
nil
}
}
func
(
c
*
passthroughTestFrameConn
)
WriteFrame
(
ctx
context
.
Context
,
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
{
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
select
{
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
default
:
}
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
c
.
writes
=
append
(
c
.
writes
,
passthroughTestFrame
{
msgType
:
msgType
,
payload
:
append
([]
byte
(
nil
),
payload
...
)})
return
nil
}
func
(
c
*
passthroughTestFrameConn
)
Close
()
error
{
c
.
once
.
Do
(
func
()
{
defer
func
()
{
_
=
recover
()
}()
close
(
c
.
readCh
)
})
return
nil
}
func
(
c
*
passthroughTestFrameConn
)
Writes
()
[]
passthroughTestFrame
{
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
out
:=
make
([]
passthroughTestFrame
,
len
(
c
.
writes
))
copy
(
out
,
c
.
writes
)
return
out
}
func
(
c
*
delayedReadFrameConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
if
c
==
nil
||
c
.
base
==
nil
{
return
coderws
.
MessageText
,
nil
,
io
.
EOF
}
c
.
once
.
Do
(
func
()
{
if
c
.
firstDelay
>
0
{
timer
:=
time
.
NewTimer
(
c
.
firstDelay
)
defer
timer
.
Stop
()
select
{
case
<-
ctx
.
Done
()
:
case
<-
timer
.
C
:
}
}
})
return
c
.
base
.
ReadFrame
(
ctx
)
}
func
(
c
*
delayedReadFrameConn
)
WriteFrame
(
ctx
context
.
Context
,
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
{
if
c
==
nil
||
c
.
base
==
nil
{
return
io
.
EOF
}
return
c
.
base
.
WriteFrame
(
ctx
,
msgType
,
payload
)
}
func
(
c
*
delayedReadFrameConn
)
Close
()
error
{
if
c
==
nil
||
c
.
base
==
nil
{
return
nil
}
return
c
.
base
.
Close
()
}
func
(
c
*
closeSpyFrameConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
<-
ctx
.
Done
()
return
coderws
.
MessageText
,
nil
,
ctx
.
Err
()
}
func
(
c
*
closeSpyFrameConn
)
WriteFrame
(
ctx
context
.
Context
,
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
select
{
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
default
:
return
nil
}
}
func
(
c
*
closeSpyFrameConn
)
Close
()
error
{
if
c
!=
nil
{
c
.
closeCalls
.
Add
(
1
)
}
return
nil
}
func
(
c
*
closeSpyFrameConn
)
CloseCalls
()
int32
{
if
c
==
nil
{
return
0
}
return
c
.
closeCalls
.
Load
()
}
func
TestRelay_BasicRelayAndUsage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
Nil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"gpt-5.3-codex"
,
result
.
RequestModel
)
require
.
Equal
(
t
,
"resp_123"
,
result
.
RequestID
)
require
.
Equal
(
t
,
"response.completed"
,
result
.
TerminalEventType
)
require
.
Equal
(
t
,
7
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
3
,
result
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
2
,
result
.
Usage
.
CacheReadInputTokens
)
require
.
NotNil
(
t
,
result
.
FirstTokenMs
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
ClientToUpstreamFrames
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
UpstreamToClientFrames
)
require
.
Equal
(
t
,
int64
(
0
),
result
.
DroppedDownstreamFrames
)
upstreamWrites
:=
upstreamConn
.
Writes
()
require
.
Len
(
t
,
upstreamWrites
,
1
)
require
.
Equal
(
t
,
coderws
.
MessageText
,
upstreamWrites
[
0
]
.
msgType
)
require
.
JSONEq
(
t
,
string
(
firstPayload
),
string
(
upstreamWrites
[
0
]
.
payload
))
clientWrites
:=
clientConn
.
Writes
()
require
.
Len
(
t
,
clientWrites
,
1
)
require
.
Equal
(
t
,
coderws
.
MessageText
,
clientWrites
[
0
]
.
msgType
)
require
.
JSONEq
(
t
,
`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`
,
string
(
clientWrites
[
0
]
.
payload
))
}
func
TestRelay_FunctionCallOutputBytesPreserved
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
Nil
(
t
,
relayExit
)
upstreamWrites
:=
upstreamConn
.
Writes
()
require
.
Len
(
t
,
upstreamWrites
,
1
)
require
.
Equal
(
t
,
coderws
.
MessageText
,
upstreamWrites
[
0
]
.
msgType
)
require
.
Equal
(
t
,
firstPayload
,
upstreamWrites
[
0
]
.
payload
)
}
func
TestRelay_UpstreamDisconnect
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 上游立即关闭(EOF),客户端不发送额外帧
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
true
)
// 立即 close -> EOF
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
// 上游 EOF 属于 disconnect,标记为 graceful
require
.
Nil
(
t
,
relayExit
,
"上游 EOF 应被视为 graceful disconnect"
)
require
.
Equal
(
t
,
"gpt-4o"
,
result
.
RequestModel
)
}
func
TestRelay_ClientDisconnect
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 客户端立即关闭(EOF),上游阻塞读取直到 context 取消
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
true
)
// 立即 close -> EOF
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
NotNil
(
t
,
relayExit
,
"客户端 EOF 应返回可观测的中断状态"
)
require
.
Equal
(
t
,
"client_disconnected"
,
relayExit
.
Stage
)
require
.
Equal
(
t
,
"gpt-4o"
,
result
.
RequestModel
)
}
func
TestRelay_ClientDisconnect_DrainCapturesLateUsage
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
true
)
upstreamBase
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`
),
},
},
true
)
upstreamConn
:=
&
delayedReadFrameConn
{
base
:
upstreamBase
,
firstDelay
:
80
*
time
.
Millisecond
,
}
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
UpstreamDrainTimeout
:
400
*
time
.
Millisecond
,
})
require
.
NotNil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"client_disconnected"
,
relayExit
.
Stage
)
require
.
Equal
(
t
,
"resp_drain"
,
result
.
RequestID
)
require
.
Equal
(
t
,
"response.completed"
,
result
.
TerminalEventType
)
require
.
Equal
(
t
,
6
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
4
,
result
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
1
,
result
.
Usage
.
CacheReadInputTokens
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
ClientToUpstreamFrames
)
require
.
Equal
(
t
,
int64
(
0
),
result
.
UpstreamToClientFrames
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
DroppedDownstreamFrames
)
}
func
TestRelay_IdleTimeout
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 客户端和上游都不发送帧,idle timeout 应触发
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
// 使用快进时间来加速 idle timeout
now
:=
time
.
Now
()
callCount
:=
0
nowFn
:=
func
()
time
.
Time
{
callCount
++
// 前几次调用返回正常时间(初始化阶段),之后快进
if
callCount
<=
5
{
return
now
}
return
now
.
Add
(
time
.
Hour
)
// 快进到超时
}
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
IdleTimeout
:
2
*
time
.
Second
,
Now
:
nowFn
,
})
require
.
NotNil
(
t
,
relayExit
,
"应因 idle timeout 退出"
)
require
.
Equal
(
t
,
"idle_timeout"
,
relayExit
.
Stage
)
require
.
Equal
(
t
,
"gpt-4o"
,
result
.
RequestModel
)
}
func
TestRelay_IdleTimeoutDoesNotCloseClientOnError
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
&
closeSpyFrameConn
{}
upstreamConn
:=
&
closeSpyFrameConn
{}
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
now
:=
time
.
Now
()
callCount
:=
0
nowFn
:=
func
()
time
.
Time
{
callCount
++
if
callCount
<=
5
{
return
now
}
return
now
.
Add
(
time
.
Hour
)
}
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
IdleTimeout
:
2
*
time
.
Second
,
Now
:
nowFn
,
})
require
.
NotNil
(
t
,
relayExit
,
"应因 idle timeout 退出"
)
require
.
Equal
(
t
,
"idle_timeout"
,
relayExit
.
Stage
)
require
.
Zero
(
t
,
clientConn
.
CloseCalls
(),
"错误路径不应提前关闭客户端连接,交给上层决定 close code"
)
require
.
GreaterOrEqual
(
t
,
upstreamConn
.
CloseCalls
(),
int32
(
1
))
}
func
TestRelay_NilConnections
(
t
*
testing
.
T
)
{
t
.
Parallel
()
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
:=
context
.
Background
()
t
.
Run
(
"nil client conn"
,
func
(
t
*
testing
.
T
)
{
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
true
)
_
,
relayExit
:=
Relay
(
ctx
,
nil
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
NotNil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"relay_init"
,
relayExit
.
Stage
)
require
.
Contains
(
t
,
relayExit
.
Err
.
Error
(),
"nil"
)
})
t
.
Run
(
"nil upstream conn"
,
func
(
t
*
testing
.
T
)
{
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
true
)
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
nil
,
firstPayload
,
RelayOptions
{})
require
.
NotNil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"relay_init"
,
relayExit
.
Stage
)
require
.
Contains
(
t
,
relayExit
.
Err
.
Error
(),
"nil"
)
})
}
func
TestRelay_MultipleUpstreamMessages
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 上游发送多个事件(delta + completed),验证多帧中继和 usage 聚合
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.output_text.delta","delta":"Hello"}`
),
},
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.output_text.delta","delta":" world"}`
),
},
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
Nil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"resp_multi"
,
result
.
RequestID
)
require
.
Equal
(
t
,
"response.completed"
,
result
.
TerminalEventType
)
require
.
Equal
(
t
,
10
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
5
,
result
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
3
,
result
.
Usage
.
CacheReadInputTokens
)
require
.
NotNil
(
t
,
result
.
FirstTokenMs
)
// 验证所有 3 个上游帧都转发给了客户端
clientWrites
:=
clientConn
.
Writes
()
require
.
Len
(
t
,
clientWrites
,
3
)
}
func
TestRelay_OnTurnComplete_PerTerminalEvent
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`
),
},
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
turns
:=
make
([]
RelayTurnResult
,
0
,
2
)
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
OnTurnComplete
:
func
(
turn
RelayTurnResult
)
{
turns
=
append
(
turns
,
turn
)
},
})
require
.
Nil
(
t
,
relayExit
)
require
.
Len
(
t
,
turns
,
2
)
require
.
Equal
(
t
,
"resp_turn_1"
,
turns
[
0
]
.
RequestID
)
require
.
Equal
(
t
,
"response.completed"
,
turns
[
0
]
.
TerminalEventType
)
require
.
Equal
(
t
,
2
,
turns
[
0
]
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
1
,
turns
[
0
]
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
"resp_turn_2"
,
turns
[
1
]
.
RequestID
)
require
.
Equal
(
t
,
"response.failed"
,
turns
[
1
]
.
TerminalEventType
)
require
.
Equal
(
t
,
3
,
turns
[
1
]
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
4
,
turns
[
1
]
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
5
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
5
,
result
.
Usage
.
OutputTokens
)
}
func
TestRelay_OnTurnComplete_ProvidesTurnMetrics
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`
),
},
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
base
:=
time
.
Unix
(
0
,
0
)
var
nowTick
atomic
.
Int64
nowFn
:=
func
()
time
.
Time
{
step
:=
nowTick
.
Add
(
1
)
return
base
.
Add
(
time
.
Duration
(
step
)
*
5
*
time
.
Millisecond
)
}
var
turn
RelayTurnResult
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
Now
:
nowFn
,
OnTurnComplete
:
func
(
current
RelayTurnResult
)
{
turn
=
current
},
})
require
.
Nil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"resp_metric"
,
turn
.
RequestID
)
require
.
Equal
(
t
,
"response.completed"
,
turn
.
TerminalEventType
)
require
.
NotNil
(
t
,
turn
.
FirstTokenMs
)
require
.
GreaterOrEqual
(
t
,
*
turn
.
FirstTokenMs
,
0
)
require
.
Greater
(
t
,
turn
.
Duration
.
Milliseconds
(),
int64
(
0
))
require
.
NotNil
(
t
,
result
.
FirstTokenMs
)
require
.
Greater
(
t
,
result
.
Duration
.
Milliseconds
(),
int64
(
0
))
}
func
TestRelay_BinaryFramePassthrough
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 验证 binary frame 被透传但不进行 usage 解析
binaryPayload
:=
[]
byte
{
0x00
,
0x01
,
0x02
,
0x03
}
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageBinary
,
payload
:
binaryPayload
,
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
Nil
(
t
,
relayExit
)
// binary frame 不解析 usage
require
.
Equal
(
t
,
0
,
result
.
Usage
.
InputTokens
)
clientWrites
:=
clientConn
.
Writes
()
require
.
Len
(
t
,
clientWrites
,
1
)
require
.
Equal
(
t
,
coderws
.
MessageBinary
,
clientWrites
[
0
]
.
msgType
)
require
.
Equal
(
t
,
binaryPayload
,
clientWrites
[
0
]
.
payload
)
}
func
TestRelay_BinaryJSONFrameSkipsObservation
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageBinary
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
Nil
(
t
,
relayExit
)
require
.
Equal
(
t
,
0
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
""
,
result
.
RequestID
)
require
.
Equal
(
t
,
""
,
result
.
TerminalEventType
)
clientWrites
:=
clientConn
.
Writes
()
require
.
Len
(
t
,
clientWrites
,
1
)
require
.
Equal
(
t
,
coderws
.
MessageBinary
,
clientWrites
[
0
]
.
msgType
)
}
func
TestRelay_UpstreamErrorEventPassthroughRaw
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
errorEvent
:=
[]
byte
(
`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
errorEvent
,
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
Nil
(
t
,
relayExit
)
clientWrites
:=
clientConn
.
Writes
()
require
.
Len
(
t
,
clientWrites
,
1
)
require
.
Equal
(
t
,
coderws
.
MessageText
,
clientWrites
[
0
]
.
msgType
)
require
.
Equal
(
t
,
errorEvent
,
clientWrites
[
0
]
.
payload
)
}
func
TestRelay_PreservesFirstMessageType
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
FirstMessageType
:
coderws
.
MessageBinary
,
})
require
.
Nil
(
t
,
relayExit
)
upstreamWrites
:=
upstreamConn
.
Writes
()
require
.
Len
(
t
,
upstreamWrites
,
1
)
require
.
Equal
(
t
,
coderws
.
MessageBinary
,
upstreamWrites
[
0
]
.
msgType
)
require
.
Equal
(
t
,
firstPayload
,
upstreamWrites
[
0
]
.
payload
)
}
func
TestRelay_UsageParseFailureDoesNotBlockRelay
(
t
*
testing
.
T
)
{
baseline
:=
SnapshotMetrics
()
.
UsageParseFailureTotal
// 上游发送无效 JSON(非 usage 格式),不应影响透传
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
result
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
require
.
Nil
(
t
,
relayExit
)
// usage 解析失败,值为 0 但不影响透传
require
.
Equal
(
t
,
0
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
"response.completed"
,
result
.
TerminalEventType
)
// 帧仍然被转发
clientWrites
:=
clientConn
.
Writes
()
require
.
Len
(
t
,
clientWrites
,
1
)
require
.
GreaterOrEqual
(
t
,
SnapshotMetrics
()
.
UsageParseFailureTotal
,
baseline
+
1
)
}
func
TestRelay_WriteUpstreamFirstMessageFails
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 上游连接立即关闭,首包写入失败
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
true
)
_
=
upstreamConn
.
Close
()
// 覆盖 WriteFrame 使其返回错误
errConn
:=
&
errorOnWriteFrameConn
{}
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
errConn
,
firstPayload
,
RelayOptions
{})
require
.
NotNil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"write_upstream"
,
relayExit
.
Stage
)
}
func
TestRelay_ContextCanceled
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
// 立即取消 context
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{})
// context 取消导致写首包失败
require
.
NotNil
(
t
,
relayExit
)
}
func
TestRelay_TraceEvents_ContainsLifecycleStages
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
([]
passthroughTestFrame
{
{
msgType
:
coderws
.
MessageText
,
payload
:
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`
),
},
},
true
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
stages
:=
make
([]
string
,
0
,
8
)
var
stagesMu
sync
.
Mutex
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
OnTrace
:
func
(
event
RelayTraceEvent
)
{
stagesMu
.
Lock
()
stages
=
append
(
stages
,
event
.
Stage
)
stagesMu
.
Unlock
()
},
})
require
.
Nil
(
t
,
relayExit
)
stagesMu
.
Lock
()
capturedStages
:=
append
([]
string
(
nil
),
stages
...
)
stagesMu
.
Unlock
()
require
.
Contains
(
t
,
capturedStages
,
"relay_start"
)
require
.
Contains
(
t
,
capturedStages
,
"write_first_message_ok"
)
require
.
Contains
(
t
,
capturedStages
,
"first_exit"
)
require
.
Contains
(
t
,
capturedStages
,
"relay_complete"
)
}
func
TestRelay_TraceEvents_IdleTimeout
(
t
*
testing
.
T
)
{
t
.
Parallel
()
clientConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
upstreamConn
:=
newPassthroughTestFrameConn
(
nil
,
false
)
firstPayload
:=
[]
byte
(
`{"type":"response.create","model":"gpt-4o","input":[]}`
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
now
:=
time
.
Now
()
callCount
:=
0
nowFn
:=
func
()
time
.
Time
{
callCount
++
if
callCount
<=
5
{
return
now
}
return
now
.
Add
(
time
.
Hour
)
}
stages
:=
make
([]
string
,
0
,
8
)
var
stagesMu
sync
.
Mutex
_
,
relayExit
:=
Relay
(
ctx
,
clientConn
,
upstreamConn
,
firstPayload
,
RelayOptions
{
IdleTimeout
:
2
*
time
.
Second
,
Now
:
nowFn
,
OnTrace
:
func
(
event
RelayTraceEvent
)
{
stagesMu
.
Lock
()
stages
=
append
(
stages
,
event
.
Stage
)
stagesMu
.
Unlock
()
},
})
require
.
NotNil
(
t
,
relayExit
)
require
.
Equal
(
t
,
"idle_timeout"
,
relayExit
.
Stage
)
stagesMu
.
Lock
()
capturedStages
:=
append
([]
string
(
nil
),
stages
...
)
stagesMu
.
Unlock
()
require
.
Contains
(
t
,
capturedStages
,
"idle_timeout_triggered"
)
require
.
Contains
(
t
,
capturedStages
,
"relay_exit"
)
}
// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。
type
errorOnWriteFrameConn
struct
{}
func
(
c
*
errorOnWriteFrameConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
<-
ctx
.
Done
()
return
coderws
.
MessageText
,
nil
,
ctx
.
Err
()
}
func
(
c
*
errorOnWriteFrameConn
)
WriteFrame
(
_
context
.
Context
,
_
coderws
.
MessageType
,
_
[]
byte
)
error
{
return
errors
.
New
(
"write failed: connection refused"
)
}
func
(
c
*
errorOnWriteFrameConn
)
Close
()
error
{
return
nil
}
backend/internal/service/openai_ws_v2_passthrough_adapter.go
0 → 100644
View file @
7076717b
package
service
import
(
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
openaiwsv2
"github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type
openAIWSClientFrameConn
struct
{
conn
*
coderws
.
Conn
}
const
openaiWSV2PassthroughModeFields
=
"ws_mode=passthrough ws_router=v2"
var
_
openaiwsv2
.
FrameConn
=
(
*
openAIWSClientFrameConn
)(
nil
)
func
(
c
*
openAIWSClientFrameConn
)
ReadFrame
(
ctx
context
.
Context
)
(
coderws
.
MessageType
,
[]
byte
,
error
)
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
coderws
.
MessageText
,
nil
,
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
return
c
.
conn
.
Read
(
ctx
)
}
func
(
c
*
openAIWSClientFrameConn
)
WriteFrame
(
ctx
context
.
Context
,
msgType
coderws
.
MessageType
,
payload
[]
byte
)
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
errOpenAIWSConnClosed
}
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
return
c
.
conn
.
Write
(
ctx
,
msgType
,
payload
)
}
func
(
c
*
openAIWSClientFrameConn
)
Close
()
error
{
if
c
==
nil
||
c
.
conn
==
nil
{
return
nil
}
_
=
c
.
conn
.
Close
(
coderws
.
StatusNormalClosure
,
""
)
_
=
c
.
conn
.
CloseNow
()
return
nil
}
func
(
s
*
OpenAIGatewayService
)
proxyResponsesWebSocketV2Passthrough
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
clientConn
*
coderws
.
Conn
,
account
*
Account
,
token
string
,
firstClientMessage
[]
byte
,
hooks
*
OpenAIWSIngressHooks
,
wsDecision
OpenAIWSProtocolDecision
,
)
error
{
if
s
==
nil
{
return
errors
.
New
(
"service is nil"
)
}
if
clientConn
==
nil
{
return
errors
.
New
(
"client websocket is nil"
)
}
if
account
==
nil
{
return
errors
.
New
(
"account is nil"
)
}
if
strings
.
TrimSpace
(
token
)
==
""
{
return
errors
.
New
(
"token is empty"
)
}
requestModel
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
firstClientMessage
,
"model"
)
.
String
())
requestPreviousResponseID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
firstClientMessage
,
"previous_response_id"
)
.
String
())
logOpenAIWSV2Passthrough
(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
requestModel
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
requestPreviousResponseID
,
openAIWSIDValueMaxLen
),
openaiwsv2RelayMessageTypeName
(
coderws
.
MessageText
),
len
(
firstClientMessage
),
)
wsURL
,
err
:=
s
.
buildOpenAIResponsesWSURL
(
account
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"build ws url: %w"
,
err
)
}
wsHost
:=
"-"
wsPath
:=
"-"
if
parsedURL
,
parseErr
:=
url
.
Parse
(
wsURL
);
parseErr
==
nil
&&
parsedURL
!=
nil
{
wsHost
=
normalizeOpenAIWSLogValue
(
parsedURL
.
Host
)
wsPath
=
normalizeOpenAIWSLogValue
(
parsedURL
.
Path
)
}
logOpenAIWSV2Passthrough
(
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v"
,
account
.
ID
,
wsHost
,
wsPath
,
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
,
)
isCodexCLI
:=
false
if
c
!=
nil
{
isCodexCLI
=
openai
.
IsCodexCLIRequest
(
c
.
GetHeader
(
"User-Agent"
))
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
ForceCodexCLI
{
isCodexCLI
=
true
}
headers
,
_
:=
s
.
buildOpenAIWSHeaders
(
c
,
account
,
token
,
wsDecision
,
isCodexCLI
,
""
,
""
,
""
)
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
dialer
:=
s
.
getOpenAIWSPassthroughDialer
()
if
dialer
==
nil
{
return
errors
.
New
(
"openai ws passthrough dialer is nil"
)
}
dialCtx
,
cancelDial
:=
context
.
WithTimeout
(
ctx
,
s
.
openAIWSDialTimeout
())
defer
cancelDial
()
upstreamConn
,
statusCode
,
handshakeHeaders
,
err
:=
dialer
.
Dial
(
dialCtx
,
wsURL
,
headers
,
proxyURL
)
if
err
!=
nil
{
logOpenAIWSV2Passthrough
(
"relay_dial_failed account_id=%d status_code=%d err=%s"
,
account
.
ID
,
statusCode
,
truncateOpenAIWSLogValue
(
err
.
Error
(),
openAIWSLogValueMaxLen
),
)
return
s
.
mapOpenAIWSPassthroughDialError
(
err
,
statusCode
,
handshakeHeaders
)
}
defer
func
()
{
_
=
upstreamConn
.
Close
()
}()
logOpenAIWSV2Passthrough
(
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s"
,
account
.
ID
,
statusCode
,
openAIWSHeaderValueForLog
(
handshakeHeaders
,
"x-request-id"
),
)
upstreamFrameConn
,
ok
:=
upstreamConn
.
(
openaiwsv2
.
FrameConn
)
if
!
ok
{
return
errors
.
New
(
"openai ws passthrough upstream connection does not support frame relay"
)
}
completedTurns
:=
atomic
.
Int32
{}
relayResult
,
relayExit
:=
openaiwsv2
.
RunEntry
(
openaiwsv2
.
EntryInput
{
Ctx
:
ctx
,
ClientConn
:
&
openAIWSClientFrameConn
{
conn
:
clientConn
},
UpstreamConn
:
upstreamFrameConn
,
FirstClientMessage
:
firstClientMessage
,
Options
:
openaiwsv2
.
RelayOptions
{
WriteTimeout
:
s
.
openAIWSWriteTimeout
(),
IdleTimeout
:
s
.
openAIWSPassthroughIdleTimeout
(),
FirstMessageType
:
coderws
.
MessageText
,
OnUsageParseFailure
:
func
(
eventType
string
,
usageRaw
string
)
{
logOpenAIWSV2Passthrough
(
"usage_parse_failed event_type=%s usage_raw=%s"
,
truncateOpenAIWSLogValue
(
eventType
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
usageRaw
,
openAIWSLogValueMaxLen
),
)
},
OnTurnComplete
:
func
(
turn
openaiwsv2
.
RelayTurnResult
)
{
turnNo
:=
int
(
completedTurns
.
Add
(
1
))
turnResult
:=
&
OpenAIForwardResult
{
RequestID
:
turn
.
RequestID
,
Usage
:
OpenAIUsage
{
InputTokens
:
turn
.
Usage
.
InputTokens
,
OutputTokens
:
turn
.
Usage
.
OutputTokens
,
CacheCreationInputTokens
:
turn
.
Usage
.
CacheCreationInputTokens
,
CacheReadInputTokens
:
turn
.
Usage
.
CacheReadInputTokens
,
},
Model
:
turn
.
RequestModel
,
Stream
:
true
,
OpenAIWSMode
:
true
,
Duration
:
turn
.
Duration
,
FirstTokenMs
:
turn
.
FirstTokenMs
,
}
logOpenAIWSV2Passthrough
(
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d"
,
account
.
ID
,
turnNo
,
truncateOpenAIWSLogValue
(
turnResult
.
RequestID
,
openAIWSIDValueMaxLen
),
truncateOpenAIWSLogValue
(
turn
.
TerminalEventType
,
openAIWSLogValueMaxLen
),
turnResult
.
Duration
.
Milliseconds
(),
openAIWSFirstTokenMsForLog
(
turnResult
.
FirstTokenMs
),
turnResult
.
Usage
.
InputTokens
,
turnResult
.
Usage
.
OutputTokens
,
turnResult
.
Usage
.
CacheReadInputTokens
,
)
if
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
hooks
.
AfterTurn
(
turnNo
,
turnResult
,
nil
)
}
},
OnTrace
:
func
(
event
openaiwsv2
.
RelayTraceEvent
)
{
logOpenAIWSV2Passthrough
(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
event
.
Stage
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
event
.
Direction
,
openAIWSLogValueMaxLen
),
truncateOpenAIWSLogValue
(
event
.
MessageType
,
openAIWSLogValueMaxLen
),
event
.
PayloadBytes
,
event
.
Graceful
,
event
.
WroteDownstream
,
truncateOpenAIWSLogValue
(
event
.
Error
,
openAIWSLogValueMaxLen
),
)
},
},
})
result
:=
&
OpenAIForwardResult
{
RequestID
:
relayResult
.
RequestID
,
Usage
:
OpenAIUsage
{
InputTokens
:
relayResult
.
Usage
.
InputTokens
,
OutputTokens
:
relayResult
.
Usage
.
OutputTokens
,
CacheCreationInputTokens
:
relayResult
.
Usage
.
CacheCreationInputTokens
,
CacheReadInputTokens
:
relayResult
.
Usage
.
CacheReadInputTokens
,
},
Model
:
relayResult
.
RequestModel
,
Stream
:
true
,
OpenAIWSMode
:
true
,
Duration
:
relayResult
.
Duration
,
FirstTokenMs
:
relayResult
.
FirstTokenMs
,
}
turnCount
:=
int
(
completedTurns
.
Load
())
if
relayExit
==
nil
{
logOpenAIWSV2Passthrough
(
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
result
.
RequestID
,
openAIWSIDValueMaxLen
),
truncateOpenAIWSLogValue
(
relayResult
.
TerminalEventType
,
openAIWSLogValueMaxLen
),
result
.
Duration
.
Milliseconds
(),
relayResult
.
ClientToUpstreamFrames
,
relayResult
.
UpstreamToClientFrames
,
relayResult
.
DroppedDownstreamFrames
,
turnCount
,
)
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
if
turnCount
==
0
&&
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
hooks
.
AfterTurn
(
1
,
result
,
nil
)
}
return
nil
}
logOpenAIWSV2Passthrough
(
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d"
,
account
.
ID
,
truncateOpenAIWSLogValue
(
relayExit
.
Stage
,
openAIWSLogValueMaxLen
),
relayExit
.
WroteDownstream
,
truncateOpenAIWSLogValue
(
relayErrorText
(
relayExit
.
Err
),
openAIWSLogValueMaxLen
),
result
.
Duration
.
Milliseconds
(),
relayResult
.
ClientToUpstreamFrames
,
relayResult
.
UpstreamToClientFrames
,
relayResult
.
DroppedDownstreamFrames
,
turnCount
,
)
relayErr
:=
relayExit
.
Err
if
relayExit
.
Stage
==
"idle_timeout"
{
relayErr
=
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"client websocket idle timeout"
,
relayErr
,
)
}
turnErr
:=
wrapOpenAIWSIngressTurnError
(
relayExit
.
Stage
,
relayErr
,
relayExit
.
WroteDownstream
,
)
if
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
hooks
.
AfterTurn
(
turnCount
+
1
,
nil
,
turnErr
)
}
return
turnErr
}
func
(
s
*
OpenAIGatewayService
)
mapOpenAIWSPassthroughDialError
(
err
error
,
statusCode
int
,
handshakeHeaders
http
.
Header
,
)
error
{
if
err
==
nil
{
return
nil
}
wrappedErr
:=
err
var
dialErr
*
openAIWSDialError
if
!
errors
.
As
(
err
,
&
dialErr
)
{
wrappedErr
=
&
openAIWSDialError
{
StatusCode
:
statusCode
,
ResponseHeaders
:
cloneHeader
(
handshakeHeaders
),
Err
:
err
,
}
}
if
errors
.
Is
(
err
,
context
.
Canceled
)
{
return
err
}
if
errors
.
Is
(
err
,
context
.
DeadlineExceeded
)
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusTryAgainLater
,
"upstream websocket connect timeout"
,
wrappedErr
,
)
}
if
statusCode
==
http
.
StatusTooManyRequests
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusTryAgainLater
,
"upstream websocket is busy, please retry later"
,
wrappedErr
,
)
}
if
statusCode
==
http
.
StatusUnauthorized
||
statusCode
==
http
.
StatusForbidden
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"upstream websocket authentication failed"
,
wrappedErr
,
)
}
if
statusCode
>=
http
.
StatusBadRequest
&&
statusCode
<
http
.
StatusInternalServerError
{
return
NewOpenAIWSClientCloseError
(
coderws
.
StatusPolicyViolation
,
"upstream websocket handshake rejected"
,
wrappedErr
,
)
}
return
fmt
.
Errorf
(
"openai ws passthrough dial: %w"
,
wrappedErr
)
}
func
openaiwsv2RelayMessageTypeName
(
msgType
coderws
.
MessageType
)
string
{
switch
msgType
{
case
coderws
.
MessageText
:
return
"text"
case
coderws
.
MessageBinary
:
return
"binary"
default
:
return
fmt
.
Sprintf
(
"unknown(%d)"
,
msgType
)
}
}
// relayErrorText returns err's message, or the empty string when err is nil,
// so that it can be logged without a nil check at the call site.
func relayErrorText(err error) string {
	text := ""
	if err != nil {
		text = err.Error()
	}
	return text
}
// openAIWSFirstTokenMsForLog dereferences firstTokenMs for logging, using -1
// as the placeholder when the pointer is nil.
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
	if firstTokenMs != nil {
		return *firstTokenMs
	}
	return -1
}
func
logOpenAIWSV2Passthrough
(
format
string
,
args
...
any
)
{
logger
.
LegacyPrintf
(
"service.openai_ws_v2"
,
"[OpenAI WS v2 passthrough] %s "
+
format
,
append
([]
any
{
openaiWSV2PassthroughModeFields
},
args
...
)
...
,
)
}
deploy/config.example.yaml
View file @
7076717b
...
@@ -209,8 +209,9 @@ gateway:
...
@@ -209,8 +209,9 @@ gateway:
openai_ws
:
openai_ws
:
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
mode_router_v2_enabled
:
false
mode_router_v2_enabled
:
false
# ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
# ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
ingress_mode_default
:
shared
# 兼容旧值:shared/dedicated 会按 ctx_pool 处理。
ingress_mode_default
:
ctx_pool
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
enabled
:
true
enabled
:
true
# 按账号类型细分开关
# 按账号类型细分开关
...
...
frontend/src/components/account/CreateAccountModal.vue
View file @
7076717b
...
@@ -1807,7 +1807,7 @@
...
@@ -1807,7 +1807,7 @@
<
/div
>
<
/div
>
<
/div
>
<
/div
>
<!--
OpenAI
WS
Mode
三态
(
off
/
shared
/
dedicated
)
-->
<!--
OpenAI
WS
Mode
三态
(
off
/
ctx_pool
/
passthrough
)
-->
<
div
<
div
v
-
if
=
"
form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')
"
v
-
if
=
"
form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')
"
class
=
"
border-t border-gray-200 pt-4 dark:border-dark-600
"
class
=
"
border-t border-gray-200 pt-4 dark:border-dark-600
"
...
@@ -1819,7 +1819,7 @@
...
@@ -1819,7 +1819,7 @@
{{
t
(
'
admin.accounts.openai.wsModeDesc
'
)
}}
{{
t
(
'
admin.accounts.openai.wsModeDesc
'
)
}}
<
/p
>
<
/p
>
<
p
class
=
"
mt-1 text-xs text-gray-500 dark:text-gray-400
"
>
<
p
class
=
"
mt-1 text-xs text-gray-500 dark:text-gray-400
"
>
{{
t
(
'
admin.accounts.openai.ws
ModeConcurrencyHint
'
)
}}
{{
t
(
openAIWS
ModeConcurrencyHint
Key
)
}}
<
/p
>
<
/p
>
<
/div
>
<
/div
>
<
div
class
=
"
w-52
"
>
<
div
class
=
"
w-52
"
>
...
@@ -2341,10 +2341,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
...
@@ -2341,10 +2341,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import
{
formatDateTimeLocalInput
,
parseDateTimeLocalInput
}
from
'
@/utils/format
'
import
{
formatDateTimeLocalInput
,
parseDateTimeLocalInput
}
from
'
@/utils/format
'
import
{
createStableObjectKeyResolver
}
from
'
@/utils/stableObjectKey
'
import
{
createStableObjectKeyResolver
}
from
'
@/utils/stableObjectKey
'
import
{
import
{
OPENAI_WS_MODE_
DEDICATED
,
OPENAI_WS_MODE_
CTX_POOL
,
OPENAI_WS_MODE_OFF
,
OPENAI_WS_MODE_OFF
,
OPENAI_WS_MODE_
SHARED
,
OPENAI_WS_MODE_
PASSTHROUGH
,
isOpenAIWSModeEnabled
,
isOpenAIWSModeEnabled
,
resolveOpenAIWSModeConcurrencyHintKey
,
type
OpenAIWSMode
type
OpenAIWSMode
}
from
'
@/utils/openaiWsMode
'
}
from
'
@/utils/openaiWsMode
'
import
OAuthAuthorizationFlow
from
'
./OAuthAuthorizationFlow.vue
'
import
OAuthAuthorizationFlow
from
'
./OAuthAuthorizationFlow.vue
'
...
@@ -2541,8 +2542,8 @@ const geminiSelectedTier = computed(() => {
...
@@ -2541,8 +2542,8 @@ const geminiSelectedTier = computed(() => {
const
openAIWSModeOptions
=
computed
(()
=>
[
const
openAIWSModeOptions
=
computed
(()
=>
[
{
value
:
OPENAI_WS_MODE_OFF
,
label
:
t
(
'
admin.accounts.openai.wsModeOff
'
)
}
,
{
value
:
OPENAI_WS_MODE_OFF
,
label
:
t
(
'
admin.accounts.openai.wsModeOff
'
)
}
,
{
value
:
OPENAI_WS_MODE_
SHARED
,
label
:
t
(
'
admin.accounts.openai.wsMode
Shared
'
)
}
,
{
value
:
OPENAI_WS_MODE_
CTX_POOL
,
label
:
t
(
'
admin.accounts.openai.wsMode
CtxPool
'
)
}
,
{
value
:
OPENAI_WS_MODE_
DEDICATED
,
label
:
t
(
'
admin.accounts.openai.wsMode
Dedicated
'
)
}
{
value
:
OPENAI_WS_MODE_
PASSTHROUGH
,
label
:
t
(
'
admin.accounts.openai.wsMode
Passthrough
'
)
}
])
])
const
openaiResponsesWebSocketV2Mode
=
computed
({
const
openaiResponsesWebSocketV2Mode
=
computed
({
...
@@ -2561,6 +2562,10 @@ const openaiResponsesWebSocketV2Mode = computed({
...
@@ -2561,6 +2562,10 @@ const openaiResponsesWebSocketV2Mode = computed({
}
}
}
)
}
)
const
openAIWSModeConcurrencyHintKey
=
computed
(()
=>
resolveOpenAIWSModeConcurrencyHintKey
(
openaiResponsesWebSocketV2Mode
.
value
)
)
const
isOpenAIModelRestrictionDisabled
=
computed
(()
=>
const
isOpenAIModelRestrictionDisabled
=
computed
(()
=>
form
.
platform
===
'
openai
'
&&
openaiPassthroughEnabled
.
value
form
.
platform
===
'
openai
'
&&
openaiPassthroughEnabled
.
value
)
)
...
@@ -3180,10 +3185,13 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
...
@@ -3180,10 +3185,13 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
}
}
const
extra
:
Record
<
string
,
unknown
>
=
{
...(
base
||
{
}
)
}
const
extra
:
Record
<
string
,
unknown
>
=
{
...(
base
||
{
}
)
}
extra
.
openai_oauth_responses_websockets_v2_mode
=
openaiOAuthResponsesWebSocketV2Mode
.
value
if
(
accountCategory
.
value
===
'
oauth-based
'
)
{
extra
.
openai_apikey_responses_websockets_v2_mode
=
openaiAPIKeyResponsesWebSocketV2Mode
.
value
extra
.
openai_oauth_responses_websockets_v2_mode
=
openaiOAuthResponsesWebSocketV2Mode
.
value
extra
.
openai_oauth_responses_websockets_v2_enabled
=
isOpenAIWSModeEnabled
(
openaiOAuthResponsesWebSocketV2Mode
.
value
)
extra
.
openai_oauth_responses_websockets_v2_enabled
=
isOpenAIWSModeEnabled
(
openaiOAuthResponsesWebSocketV2Mode
.
value
)
extra
.
openai_apikey_responses_websockets_v2_enabled
=
isOpenAIWSModeEnabled
(
openaiAPIKeyResponsesWebSocketV2Mode
.
value
)
}
else
if
(
accountCategory
.
value
===
'
apikey
'
)
{
extra
.
openai_apikey_responses_websockets_v2_mode
=
openaiAPIKeyResponsesWebSocketV2Mode
.
value
extra
.
openai_apikey_responses_websockets_v2_enabled
=
isOpenAIWSModeEnabled
(
openaiAPIKeyResponsesWebSocketV2Mode
.
value
)
}
// 清理兼容旧键,统一改用分类型开关。
// 清理兼容旧键,统一改用分类型开关。
delete
extra
.
responses_websockets_v2_enabled
delete
extra
.
responses_websockets_v2_enabled
delete
extra
.
openai_ws_enabled
delete
extra
.
openai_ws_enabled
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment