Commit 538ae31a authored by 陈曦's avatar 陈曦
Browse files

Merge v0.1.121 and fix conflicts

parents 74828a7c 48912014
Pipeline #82338 passed with stage
in 17 seconds
......@@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
}
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent
// runs a two-turn store=false websocket session in which the second turn's
// input already contains the function_call item matching its
// function_call_output, and asserts the gateway does NOT auto-attach
// previous_response_id to the upstream request in that case.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// Minimal config: WS v2 enabled, URL allowlist relaxed so the plain-HTTP
	// httptest server address is dialable.
	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
	// Scripted fake upstream: one response.completed event per turn; the
	// conn also records every client→upstream write for later assertions.
	captureConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	captureDialer := &openAIWSQueueDialer{
		conns: []openAIWSClientConn{captureConn},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(captureDialer)
	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}
	// API-key account with the v2 passthrough flag enabled in Extra.
	account := &Account{
		ID:          114,
		Name:        "openai-ingress-tool-context",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	// Ingress side: the httptest handler accepts the client websocket, reads
	// the first frame, then hands the connection to the service under test.
	// The handler's final error (nil on clean shutdown) is reported on
	// serverErrCh so the test can wait for the proxy loop to finish.
	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()
		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req
		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}
		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
	}))
	defer wsServer.Close()
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()
	// Small helpers with per-operation 3s timeouts.
	writeMessage := func(payload string) {
		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
	}
	readMessage := func() []byte {
		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		msgType, message, readErr := clientConn.Read(readCtx)
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}
	// Turn 1: plain input, establishes resp_auto_prev_ctx_1 as the last
	// completed response.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
	// Turn 2: the input is self-contained — it carries the function_call
	// with call_ctx_1 alongside its function_call_output — so no
	// previous_response_id back-fill is needed.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
	secondTurn := readMessage()
	require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	// Wait for the ingress proxy loop to exit cleanly (5s guard).
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 结束超时")
	}
	// Both turns must reuse a single upstream connection, and the second
	// upstream request must NOT have gained a previous_response_id.
	require.Equal(t, 1, captureDialer.DialCount())
	require.Len(t, captureConn.writes, 2)
	require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id")
}
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent
// runs a two-turn store=false websocket session in which the second turn
// references its function_call only via an item_reference (not the full
// function_call item), and asserts the gateway DOES back-fill
// previous_response_id with the first turn's response ID.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// Minimal config: WS v2 enabled, URL allowlist relaxed so the plain-HTTP
	// httptest server address is dialable.
	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
	// Scripted fake upstream: one response.completed event per turn; the
	// conn also records every client→upstream write for later assertions.
	captureConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	captureDialer := &openAIWSQueueDialer{
		conns: []openAIWSClientConn{captureConn},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(captureDialer)
	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}
	// API-key account with the v2 passthrough flag enabled in Extra.
	account := &Account{
		ID:          115,
		Name:        "openai-ingress-item-reference",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	// Ingress side: accept the client websocket, read the first frame, then
	// hand the connection to the service under test; report the proxy loop's
	// final error on serverErrCh.
	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()
		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req
		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}
		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
	}))
	defer wsServer.Close()
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()
	// Small helpers with per-operation 3s timeouts.
	writeMessage := func(payload string) {
		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
	}
	readMessage := func() []byte {
		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		msgType, message, readErr := clientConn.Read(readCtx)
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}
	// Turn 1: plain input, establishes resp_auto_prev_ref_1 as the last
	// completed response.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
	// Turn 2: only an item_reference covers call_ref_1, which is not enough
	// to make the function_call_output self-contained.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
	secondTurn := readMessage()
	require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	// Wait for the ingress proxy loop to exit cleanly (5s guard).
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 结束超时")
	}
	// Single upstream connection, and the second upstream request must have
	// been back-filled with the first turn's response ID.
	require.Equal(t, 1, captureDialer.DialCount())
	require.Len(t, captureConn.writes, 2)
	require.Equal(t, "resp_auto_prev_ref_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "仅有 item_reference 不足以自包含 function_call_output,应回填上一轮响应 ID")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
gin.SetMode(gin.TestMode)
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
......
......@@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
name string
storeDisabled bool
turn int
hasFunctionCallOutput bool
signals ToolContinuationSignals
currentPreviousResponse string
expectedPrevious string
want bool
}{
{
name: "infer_when_all_conditions_match",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
want: true,
name: "infer_when_all_conditions_match",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: true,
},
{
name: "skip_when_store_enabled",
storeDisabled: false,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
want: false,
name: "skip_when_store_enabled",
storeDisabled: false,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_on_first_turn",
storeDisabled: true,
turn: 1,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
want: false,
name: "skip_on_first_turn",
storeDisabled: true,
turn: 1,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_without_function_call_output",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: false,
expectedPrevious: "resp_1",
want: false,
name: "skip_without_function_call_output",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_when_request_already_has_previous_response_id",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
currentPreviousResponse: "resp_client",
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_when_last_turn_response_id_missing",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: "",
want: false,
name: "skip_when_last_turn_response_id_missing",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "",
want: false,
},
{
name: "trim_whitespace_before_judgement",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: " resp_2 ",
want: true,
name: "trim_whitespace_before_judgement",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: " resp_2 ",
want: true,
},
{
name: "skip_when_tool_call_context_already_present",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
expectedPrevious: "resp_2",
want: false,
},
{
name: "infer_when_only_item_reference_covers_call_ids",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
expectedPrevious: "resp_2",
want: true,
},
{
name: "skip_when_function_call_output_missing_call_id",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
expectedPrevious: "resp_2",
want: false,
},
}
......@@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
tt.storeDisabled,
tt.turn,
tt.hasFunctionCallOutput,
tt.signals,
tt.currentPreviousResponse,
tt.expectedPrevious,
)
......
......@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
......@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame through the OpenAI Fast Policy. It is the
// passthrough-relay equivalent of the parseClientPayload integration in the
// ingress session path. filter returns:
//   - newPayload, nil, nil: forward the (possibly mutated) payload
//   - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
//     event via onBlock and surfaces a transport-level error so the relay
//     stops reading from the client.
//   - _, _, err: a transport error other than block.
type openAIWSPolicyEnforcingFrameConn struct {
	// inner is the wrapped client-side frame connection; all reads/writes
	// are delegated to it.
	inner openaiwsv2.FrameConn
	// filter is the per-frame policy hook applied on ReadFrame; a nil
	// filter disables filtering and passes frames through untouched.
	filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
	// onBlock, when non-nil, is invoked before the blocking error is
	// surfaced (typically to push an error event back to the client).
	onBlock func(blocked *OpenAIFastBlockedError)
}
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
// ReadFrame pulls the next frame from the wrapped client connection and
// pushes it through the policy filter before handing it to the relay.
// Transport errors from the inner read are returned as-is; a policy block
// notifies onBlock (if set) and surfaces a client-close error so the relay
// stops reading. With no filter configured, frames pass through untouched.
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.inner == nil {
		return coderws.MessageText, nil, errOpenAIWSConnClosed
	}
	mt, raw, readErr := c.inner.ReadFrame(ctx)
	switch {
	case readErr != nil:
		return mt, raw, readErr
	case c.filter == nil:
		return mt, raw, nil
	}
	filtered, blockedErr, filterErr := c.filter(mt, raw)
	if filterErr != nil {
		return mt, raw, filterErr
	}
	if blockedErr == nil {
		return mt, filtered, nil
	}
	// Blocked frame: let the caller emit the client-facing error event,
	// then surface a close error so the relay loop terminates.
	if c.onBlock != nil {
		c.onBlock(blockedErr)
	}
	return mt, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blockedErr.Message, blockedErr)
}
// WriteFrame forwards a frame to the wrapped connection unchanged; the
// policy filter applies only to the read (client→upstream) direction.
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c != nil && c.inner != nil {
		return c.inner.WriteFrame(ctx, msgType, payload)
	}
	return errOpenAIWSConnClosed
}
// Close closes the wrapped connection; closing a nil or empty wrapper is a
// no-op.
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
	if c != nil && c.inner != nil {
		return c.inner.Close()
	}
	return nil
}
// openAIWSPassthroughPolicyModelForFrame resolves the upstream-perspective
// model name to feed into evaluateOpenAIFastPolicy for a single passthrough
// WS frame. It mirrors the HTTP-side normalization
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
// matches model whitelists identically. Returns "" when the account or
// payload is absent, or the frame carries no "model" field.
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
	if account == nil || len(payload) == 0 {
		return ""
	}
	requested := gjson.GetBytes(payload, "model").String()
	if trimmed := strings.TrimSpace(requested); trimmed != "" {
		return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(trimmed))
	}
	return ""
}
// openAIWSPassthroughPolicyModelFromSessionFrame extracts the upstream model
// from a session.update frame's session.model field. It returns "" when the
// frame is not a session.update event or carries no session.model. The
// per-frame policy filter (client→upstream direction) uses it to keep
// capturedSessionModel in sync with a session-level model the client may
// rotate mid-session.
//
// Realtime / Responses WS lets the client change the session model after
// the WS handshake via:
//
//	{"type":"session.update","session":{"model":"gpt-5.5", ...}}
//
// If we only capture the model from the very first frame, a client can ship
// gpt-4o on the first response.create (whitelisted as pass), then
// session.update to gpt-5.5, then send response.create without "model" so
// the per-frame resolver returns "" and the stale capturedSessionModel falls
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
	if account == nil || len(payload) == 0 {
		return ""
	}
	if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
		return ""
	}
	sessionModel := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
	if sessionModel == "" {
		return ""
	}
	return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(sessionModel))
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
......@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
......@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
len(firstClientMessage),
)
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
// frames are filtered via a wrapping FrameConn below so every client→
// upstream frame goes through the same policy evaluator/normalize/scope as
// HTTP entrypoints.
//
// We capture the session-level model from the first frame here so the
// per-frame filter (below) can fall back to it when a follow-up frame
// omits "model" — Realtime clients are allowed to send response.create
// without re-stating the model, in which case the upstream uses the model
// negotiated at session.update time. Without this fallback, an empty
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
// silently passed through, defeating the policy on every frame after
// the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
}
if blocked != nil {
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
// writeFrameMu, writes the entire frame, and Flushes the underlying
// bufio writer before returning (write.go:42 → write.go:307-311).
// The subsequent close handshake re-acquires the same writeFrameMu
// to send the close frame, so the error event is guaranteed to
// reach the kernel send buffer before any close frame is queued.
// No explicit flush hop is required here.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancelWrite()
}
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
firstClientMessage = updatedFirst
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
// 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
//
// 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
// 同一连接的不同 response.create 帧上发送不同 service_tier(参考
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
// goroutine)之间同步当前 turn 的 service_tier。
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
// 可直接 Store/Load 而无需额外封装。
var requestServiceTierPtr atomic.Pointer[string]
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
......@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
}
completedTurns := atomic.Int32{}
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
inner: &openAIWSClientFrameConn{conn: clientConn},
// 注意线程安全:filter 仅在 runClientToUpstream 这一条
// goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
// capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
// 加锁/原子化。
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
// session.update 修改 session-level model(Realtime /
// Responses WS 协议允许),如果不刷新就会出现
// "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
// → 不带 model 的 response.create fallback 到 gpt-4o" 的
// 绕过路径。这里只看 session.update 事件中的 session.model
// 字段,response.create 自己的 model 仍然由其本帧字段决定。
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
// Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame so the
// model whitelist still resolves. An empty model would miss
// any whitelist and silently fall back to pass.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
// 多轮 passthrough billing:仅在成功(non-block / non-err)
// 的 response.create 帧上更新 requestServiceTierPtr,使用
// filter 处理后的 payload,与首帧 policy-after-extract 语义
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
// - 非 response.create 帧(response.cancel /
// conversation.item.create / session.update 等)不携带
// per-response service_tier,不应覆盖前一轮值。
// - blocked != nil:该帧不会发送上游,billing tier 应保持
// 上一轮值。
// - policyErr != nil:异常路径,保持上一轮值。
// - 不带 service_tier 的 response.create 会让
// extractOpenAIServiceTierFromBody 返回 nil;这里有意
// 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
// service_tier 时按 default 处理,billing 应如实反映。
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
}
return out, blocked, policyErr
},
onBlock: func(blocked *OpenAIFastBlockedError) {
// See note above on Conn.Write being synchronous w.r.t. flush;
// no explicit flush is required to ensure the error event lands
// before the close frame.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes == nil {
return
}
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
},
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
......@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
ServiceTier: requestServiceTier,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
......@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: requestServiceTier,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
......
......@@ -184,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string {
)
}
// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。
// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据
// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true
// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天
//
// 之所以 days==0 走 TRUNCATE 而非"now+24h cutoff + DELETE":
// - 速度从 O(N) 降到 O(1),对百万行级表毫秒完成
// - 无 WAL 写入、无后续 VACUUM 压力
// - 这些 ops 表只有 cleanup 任务自己写,TRUNCATE 的 ACCESS EXCLUSIVE 锁影响可忽略
func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
if days < 0 {
return time.Time{}, false, false
}
if days == 0 {
return time.Time{}, true, true
}
return now.AddDate(0, 0, -days), false, true
}
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
out := opsCleanupDeletedCounts{}
if s == nil || s.db == nil || s.cfg == nil {
......@@ -194,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
now := time.Now().UTC()
// Error-like tables: error logs / retry attempts / alert events.
if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
cutoff := now.AddDate(0, 0, -days)
n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
// runOne 把"truncate? cutoff? batched delete?"封装到一处,
// 让三组清理(错误日志类 / 分钟指标 / 小时+日预聚合)调用方只关心表名和列名。
runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) {
if truncate {
return truncateOpsTable(ctx, s.db, table)
}
return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate)
}
// Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits.
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok {
n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false)
if err != nil {
return out, err
}
out.errorLogs = n
n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false)
if err != nil {
return out, err
}
out.retryAttempts = n
n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false)
if err != nil {
return out, err
}
out.alertEvents = n
n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false)
n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false)
if err != nil {
return out, err
}
out.systemLogs = n
n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false)
n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false)
if err != nil {
return out, err
}
......@@ -229,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Minute-level metrics snapshots.
if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
cutoff := now.AddDate(0, 0, -days)
n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok {
n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false)
if err != nil {
return out, err
}
......@@ -239,15 +265,14 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Pre-aggregation tables (hourly/daily).
if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
cutoff := now.AddDate(0, 0, -days)
n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok {
n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false)
if err != nil {
return out, err
}
out.hourlyPreagg = n
n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true)
if err != nil {
return out, err
}
......@@ -303,7 +328,7 @@ WHERE id IN (SELECT id FROM batch)
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
if err != nil {
// If ops tables aren't present yet (partial deployments), treat as no-op.
if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
if isMissingRelationError(err) {
return total, nil
}
return total, err
......@@ -320,6 +345,46 @@ WHERE id IN (SELECT id FROM batch)
return total, nil
}
// truncateOpsTable wipes the given table with TRUNCATE TABLE, first running
// SELECT COUNT(*) so the pre-wipe row count can be reported to the heartbeat.
//
// Differences from deleteOldRowsByID:
//   - no WHERE condition is possible; only used for the days==0 "wipe all"
//     semantics
//   - releases the table's physical pages in O(1), completing in
//     milliseconds with no WAL writes and no VACUUM pressure
//   - requires an ACCESS EXCLUSIVE lock, but the ops tables are written only
//     by the cleanup task itself, so the momentary lock is negligible
//
// A missing table (partial deployments) silently returns 0, consistent with
// deleteOldRowsByID.
func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
	if db == nil {
		return 0, nil
	}
	var before int64
	countErr := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&before)
	switch {
	case countErr != nil && isMissingRelationError(countErr):
		return 0, nil
	case countErr != nil:
		return 0, fmt.Errorf("count %s: %w", table, countErr)
	case before == 0:
		// Nothing to wipe; skip the TRUNCATE (and its lock) entirely.
		return 0, nil
	}
	if _, execErr := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); execErr != nil {
		if isMissingRelationError(execErr) {
			return 0, nil
		}
		return 0, fmt.Errorf("truncate %s: %w", table, execErr)
	}
	return before, nil
}
// isMissingRelationError 判断 PG 报错是否为"表不存在",用于让清理任务在部分部署场景静默跳过。
func isMissingRelationError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
return strings.Contains(s, "does not exist") && strings.Contains(s, "relation")
}
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
if s == nil {
return nil, false
......
package service
import (
"testing"
"time"
)
// TestOpsCleanupPlan table-tests the retention-days → cleanup-action
// translation: negative skips, zero truncates, positive yields a past cutoff.
func TestOpsCleanupPlan(t *testing.T) {
	now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
	type testCase struct {
		name         string
		days         int
		wantOK       bool
		wantTruncate bool
		wantCutoff   time.Time
	}
	for _, tc := range []testCase{
		{name: "negative skips", days: -1},
		{name: "zero truncates", days: 0, wantOK: true, wantTruncate: true},
		{name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)},
	} {
		t.Run(tc.name, func(t *testing.T) {
			cutoff, truncate, ok := opsCleanupPlan(now, tc.days)
			if ok != tc.wantOK {
				t.Fatalf("ok = %v, want %v", ok, tc.wantOK)
			}
			if !ok {
				// Skipped plans carry no meaningful cutoff/truncate.
				return
			}
			if truncate != tc.wantTruncate {
				t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate)
			}
			if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) {
				t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff)
			}
		})
	}
}
// TestIsMissingRelationError exercises the PG "relation does not exist"
// classifier against nil, matching, case-varied, and unrelated errors.
func TestIsMissingRelationError(t *testing.T) {
	for _, tt := range []struct {
		name string
		err  error
		want bool
	}{
		{name: "nil is not missing", err: nil, want: false},
		{name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true},
		{name: "match case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true},
		{name: "non-matching error", err: fakeErr("connection refused"), want: false},
	} {
		t.Run(tt.name, func(t *testing.T) {
			if got := isMissingRelationError(tt.err); got != tt.want {
				t.Fatalf("got %v, want %v", got, tt.want)
			}
		})
	}
}
// fakeErr is a string-backed error used to fabricate exact Postgres-style
// error messages in tests.
type fakeErr string
func (e fakeErr) Error() string { return string(e) }
......@@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
if cfg.DataRetention.CleanupSchedule == "" {
cfg.DataRetention.CleanupSchedule = "0 2 * * *"
}
if cfg.DataRetention.ErrorLogRetentionDays <= 0 {
// 保留天数:0 表示每次定时清理全部(清空所有),> 0 表示按天数保留;
// 仅在拿到非法的负数时回填默认值,避免覆盖用户主动设的 0。
if cfg.DataRetention.ErrorLogRetentionDays < 0 {
cfg.DataRetention.ErrorLogRetentionDays = 30
}
if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 {
if cfg.DataRetention.MinuteMetricsRetentionDays < 0 {
cfg.DataRetention.MinuteMetricsRetentionDays = 30
}
if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 {
if cfg.DataRetention.HourlyMetricsRetentionDays < 0 {
cfg.DataRetention.HourlyMetricsRetentionDays = 30
}
// Normalize auto refresh interval (default 30 seconds)
......@@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
if cfg == nil {
return errors.New("invalid config")
}
if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
return errors.New("error_log_retention_days must be between 1 and 365")
// 保留天数:0 表示每次清理全部,1-365 表示按天数保留。
if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
return errors.New("error_log_retention_days must be between 0 and 365")
}
if cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
return errors.New("minute_metrics_retention_days must be between 1 and 365")
if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
return errors.New("minute_metrics_retention_days must be between 0 and 365")
}
if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
return errors.New("hourly_metrics_retention_days must be between 1 and 365")
if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
return errors.New("hourly_metrics_retention_days must be between 0 and 365")
}
if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
return errors.New("auto_refresh_interval_seconds must be between 15 and 300")
......
......@@ -59,6 +59,8 @@ type SchedulerCache interface {
UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
// TryLockBucket 尝试获取分桶重建锁。
TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
// UnlockBucket 释放分桶重建锁。
UnlockBucket(ctx context.Context, bucket SchedulerBucket) error
// ListBuckets 返回已注册的分桶集合。
ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
// GetOutboxWatermark 读取 outbox 水位。
......
......@@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched
return true, nil
}
// UnlockBucket is a no-op: the hydration cache does not hold real bucket locks
// (its TryLockBucket always reports success), so this exists only to satisfy
// the SchedulerCache interface.
func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error {
	return nil
}
// ListBuckets reports an empty bucket set; the hydration cache keeps no
// bucket registry.
func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) {
	return nil, nil
}
......
......@@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch
if !ok {
return nil
}
defer func() {
_ = s.cache.UnlockBucket(ctx, bucket)
}()
rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
......
......@@ -82,10 +82,11 @@ const backendModeDBTimeout = 5 * time.Second
// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL)
type cachedGatewayForwardingSettings struct {
fingerprintUnification bool
metadataPassthrough bool
cchSigning bool
expiresAt int64 // unix nano
fingerprintUnification bool
metadataPassthrough bool
cchSigning bool
anthropicCacheTTL1hInjection bool
expiresAt int64 // unix nano
}
var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings
......@@ -1245,6 +1246,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
......@@ -1305,10 +1307,11 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
})
gatewayForwardingSF.Forget("gateway_forwarding")
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: settings.EnableFingerprintUnification,
metadataPassthrough: settings.EnableMetadataPassthrough,
cchSigning: settings.EnableCCHSigning,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
fingerprintUnification: settings.EnableFingerprintUnification,
metadataPassthrough: settings.EnableMetadataPassthrough,
cchSigning: settings.EnableCCHSigning,
anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
......@@ -1415,22 +1418,30 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
return false
}
// GetGatewayForwardingSettings returns cached gateway forwarding settings.
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
type gatewayForwardingSettingsResult struct {
fp, mp, cch, cacheTTL1h bool
}
func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning
return gatewayForwardingSettingsResult{
fp: cached.fingerprintUnification,
mp: cached.metadataPassthrough,
cch: cached.cchSigning,
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
}
}
}
type gwfResult struct {
fp, mp, cch bool
}
val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil
return gatewayForwardingSettingsResult{
fp: cached.fingerprintUnification,
mp: cached.metadataPassthrough,
cch: cached.cchSigning,
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
}, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout)
......@@ -1439,16 +1450,18 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
SettingKeyEnableFingerprintUnification,
SettingKeyEnableMetadataPassthrough,
SettingKeyEnableCCHSigning,
SettingKeyEnableAnthropicCacheTTL1hInjection,
})
if err != nil {
slog.Warn("failed to get gateway forwarding settings", "error", err)
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: true,
metadataPassthrough: false,
cchSigning: false,
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
fingerprintUnification: true,
metadataPassthrough: false,
cchSigning: false,
anthropicCacheTTL1hInjection: false,
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
})
return gwfResult{true, false, false}, nil
return gatewayForwardingSettingsResult{fp: true}, nil
}
fp := true
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
......@@ -1456,18 +1469,33 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
}
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
cch := values[SettingKeyEnableCCHSigning] == "true"
cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: fp,
metadataPassthrough: mp,
cchSigning: cch,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
fingerprintUnification: fp,
metadataPassthrough: mp,
cchSigning: cch,
anthropicCacheTTL1hInjection: cacheTTL1h,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
return gwfResult{fp, mp, cch}, nil
return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil
})
if r, ok := val.(gwfResult); ok {
return r.fp, r.mp, r.cch
if r, ok := val.(gatewayForwardingSettingsResult); ok {
return r
}
return true, false, false // fail-open defaults
return gatewayForwardingSettingsResult{fp: true}
}
// GetGatewayForwardingSettings returns cached gateway forwarding settings.
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
// The cacheTTL1h flag from the same cached bundle is exposed separately via
// IsAnthropicCacheTTL1hInjectionEnabled.
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
	result := s.getGatewayForwardingSettingsCached(ctx)
	return result.fp, result.mp, result.cch
}
// IsAnthropicCacheTTL1hInjectionEnabled reports whether a 1h cache_control ttl
// should be injected into Anthropic OAuth/SetupToken request bodies. Reads the
// same cached forwarding-settings bundle as GetGatewayForwardingSettings.
func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Context) bool {
	return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
......@@ -1880,12 +1908,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyMaxClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
SettingKeyAllowUngroupedKeyScheduling: "false",
SettingPaymentVisibleMethodAlipaySource: "",
SettingPaymentVisibleMethodWxpaySource: "",
SettingPaymentVisibleMethodAlipayEnabled: "false",
SettingPaymentVisibleMethodWxpayEnabled: "false",
openAIAdvancedSchedulerSettingKey: "false",
SettingKeyAllowUngroupedKeyScheduling: "false",
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
SettingPaymentVisibleMethodAlipaySource: "",
SettingPaymentVisibleMethodWxpaySource: "",
SettingPaymentVisibleMethodAlipayEnabled: "false",
SettingPaymentVisibleMethodWxpayEnabled: "false",
openAIAdvancedSchedulerSettingKey: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
......@@ -2228,6 +2257,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
......@@ -3259,6 +3289,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
// GetOpenAIFastPolicySettings loads the OpenAI fast-policy configuration from
// the settings store. It falls back to defaults when the key is absent, the
// stored value is empty, or the JSON is corrupt; the corrupt case is logged so
// operators can locate the dirty settings row instead of silently losing
// admin-configured block/filter rules.
func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
	switch {
	case errors.Is(err, ErrSettingNotFound):
		return DefaultOpenAIFastPolicySettings(), nil
	case err != nil:
		return nil, fmt.Errorf("get openai fast policy settings: %w", err)
	case raw == "":
		return DefaultOpenAIFastPolicySettings(), nil
	}
	parsed := &OpenAIFastPolicySettings{}
	if err := json.Unmarshal([]byte(raw), parsed); err != nil {
		slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
			"error", err,
			"key", SettingKeyOpenAIFastPolicySettings)
		return DefaultOpenAIFastPolicySettings(), nil
	}
	return parsed, nil
}
// SetOpenAIFastPolicySettings validates, normalizes, and persists the OpenAI
// fast-policy configuration. Service tiers are trimmed/lower-cased (empty
// normalizes to "all"), whitelist patterns are trimmed, and every
// action/scope must be one of the shared beta-policy constants.
func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
	if settings == nil {
		return fmt.Errorf("settings cannot be nil")
	}
	actionOK := map[string]bool{
		BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
	}
	scopeOK := map[string]bool{
		BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
	}
	tierOK := map[string]bool{
		OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
	}
	for i := range settings.Rules {
		rule := &settings.Rules[i]
		tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
		if tier == "" {
			tier = OpenAIFastTierAny
		}
		if !tierOK[tier] {
			// Report the raw (pre-normalization) tier the caller supplied.
			return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
		}
		rule.ServiceTier = tier
		if !actionOK[rule.Action] {
			return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
		}
		if !scopeOK[rule.Scope] {
			return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
		}
		for j, pattern := range rule.ModelWhitelist {
			trimmed := strings.TrimSpace(pattern)
			if trimmed == "" {
				return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
			}
			rule.ModelWhitelist[j] = trimmed
		}
		if rule.FallbackAction != "" && !actionOK[rule.FallbackAction] {
			return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
		}
	}
	data, err := json.Marshal(settings)
	if err != nil {
		return fmt.Errorf("marshal openai fast policy settings: %w", err)
	}
	return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
}
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {
......
......@@ -149,9 +149,10 @@ type SystemSettings struct {
BackendModeEnabled bool
// Gateway forwarding behavior
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
......@@ -405,3 +406,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
},
}
}
// OpenAI fast-policy constants.
//
// OpenAI's "fast mode" is identified via the request body's service_tier field:
//   - "priority" (clients may send "fast", normalized to "priority"): fast mode
//   - "flex": low-priority mode
//   - omitted: normal default
//
// The policy reuses the BetaPolicyAction*/BetaPolicyScope* constant semantics;
// only the match key changes from the anthropic-beta header to the body's
// service_tier field.
const (
	OpenAIFastTierAny      = "all"      // matches any recognized service_tier
	OpenAIFastTierPriority = "priority" // matches only fast (priority)
	OpenAIFastTierFlex     = "flex"     // matches only flex
)
// OpenAIFastPolicyRule is a single OpenAI fast/flex policy rule.
type OpenAIFastPolicyRule struct {
	ServiceTier          string   `json:"service_tier"`                     // "priority" | "flex" | "all"; empty normalizes to "all" (SetOpenAIFastPolicySettings rejects other values)
	Action               string   `json:"action"`                           // "pass" | "filter" | "block"
	Scope                string   `json:"scope"`                            // "all" | "oauth" | "apikey" | "bedrock"
	ErrorMessage         string   `json:"error_message,omitempty"`          // custom error message (effective when action=block)
	ModelWhitelist       []string `json:"model_whitelist,omitempty"`        // model match patterns (empty = rule applies to every model)
	FallbackAction       string   `json:"fallback_action,omitempty"`        // how to handle models that do not match the whitelist
	FallbackErrorMessage string   `json:"fallback_error_message,omitempty"` // custom error message for non-whitelisted models (effective when fallback_action=block)
}
// OpenAIFastPolicySettings holds the configured OpenAI fast/flex policy rules.
type OpenAIFastPolicySettings struct {
	Rules []OpenAIFastPolicyRule `json:"rules"`
}
// DefaultOpenAIFastPolicySettings returns the default OpenAI fast-policy
// configuration: filter priority ("fast") requests for all models, i.e. strip
// the service_tier field so the upstream handles them at normal priority.
//
// Why ModelWhitelist is empty (= applies to every model): the codex client's
// service_tier=fast is a user-level switch orthogonal to the model field —
// even gpt-4 + fast consumes priority quota. Scoping the default rule to only
// gpt-5.5* would leave "gpt-4 + fast passes priority upstream" as a bypass, so
// the default matches codex's real semantics and covers all models; admins can
// narrow it with model_whitelist in the admin UI.
func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
	defaultRule := OpenAIFastPolicyRule{
		ServiceTier:    OpenAIFastTierPriority,
		Action:         BetaPolicyActionFilter,
		Scope:          BetaPolicyScopeAll,
		ModelWhitelist: []string{},
		FallbackAction: BetaPolicyActionPass,
	}
	return &OpenAIFastPolicySettings{Rules: []OpenAIFastPolicyRule{defaultRule}}
}
package service
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Vertex AI service-account and endpoint constants.
const (
	vertexDefaultLocation         = "us-central1"                                     // region used when the account specifies none
	vertexDefaultTokenURL         = "https://oauth2.googleapis.com/token"             // well-known Google token endpoint (always forced; see parseVertexServiceAccountJSON)
	vertexCloudPlatformScope      = "https://www.googleapis.com/auth/cloud-platform"  // OAuth scope requested for service-account tokens
	vertexServiceAccountCacheSkew = 5 * time.Minute                                   // trimmed from token TTL so cached tokens expire before upstream does
	vertexLockWaitTime            = 200 * time.Millisecond                            // pause before re-checking cache when another worker holds the refresh lock
	vertexAnthropicVersion        = "vertex-2023-10-16"                               // anthropic_version value injected into Vertex request bodies
)
var (
	// vertexLocationPattern restricts locations to lowercase alphanumerics and
	// dashes before they are interpolated into endpoint hostnames and paths.
	vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`)
	// vertexAnthropicDatedModelIDPattern matches dash-dated model ids such as
	// "claude-sonnet-4-5-20250929", capturing the base name and the date.
	vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
	// vertexAnthropicAlreadyDatedIDPattern matches ids already in Vertex's
	// "@YYYYMMDD" form, which need no conversion.
	vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
)
// vertexServiceAccountKey is the subset of a GCP service-account JSON key
// needed to mint OAuth2 access tokens for Vertex AI.
type vertexServiceAccountKey struct {
	Type         string `json:"type"`
	ProjectID    string `json:"project_id"`
	PrivateKeyID string `json:"private_key_id"`
	PrivateKey   string `json:"private_key"` // PEM-encoded RSA private key
	ClientEmail  string `json:"client_email"`
	TokenURI     string `json:"token_uri"` // overwritten with the well-known Google endpoint after parsing
}
// vertexTokenResponse models the OAuth2 token endpoint reply; success and
// error variants share one shape.
type vertexTokenResponse struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
	ExpiresIn   int64  `json:"expires_in"` // lifetime in seconds
	Error       string `json:"error"`
	ErrorDesc   string `json:"error_description"`
}
// IsVertexServiceAccount reports whether the account authenticates to Vertex
// using a GCP service-account key (nil-safe).
func (a *Account) IsVertexServiceAccount() bool {
	return a != nil && a.Type == AccountTypeServiceAccount
}
// VertexProjectID resolves the GCP project id for the account: an explicit
// "project_id" credential wins; otherwise it is taken from the parsed
// service-account JSON. Returns "" when neither source yields one (nil-safe).
func (a *Account) VertexProjectID() string {
	if a == nil {
		return ""
	}
	if explicit := strings.TrimSpace(a.GetCredential("project_id")); explicit != "" {
		return explicit
	}
	if key, err := parseVertexServiceAccountKey(a); err == nil {
		return strings.TrimSpace(key.ProjectID)
	}
	return ""
}
// VertexLocation resolves the Vertex region for the given model. Precedence:
// per-model override under credentials["vertex_model_locations"], then the
// "location" credential, then "vertex_location", then the package default.
func (a *Account) VertexLocation(model string) string {
	if a == nil {
		return vertexDefaultLocation
	}
	if a.Credentials != nil && model != "" {
		if overrides, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok {
			if loc, ok := overrides[model].(string); ok {
				if trimmed := strings.TrimSpace(loc); trimmed != "" {
					return trimmed
				}
			}
		}
	}
	for _, credKey := range []string{"location", "vertex_location"} {
		if v := strings.TrimSpace(a.GetCredential(credKey)); v != "" {
			return v
		}
	}
	return vertexDefaultLocation
}
// parseVertexServiceAccountKey extracts the service-account key from an
// account's credentials. It accepts a JSON string first, then a nested object,
// under either "service_account_json" or "service_account".
func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
	if account == nil || account.Credentials == nil {
		return nil, errors.New("service account credentials not configured")
	}
	credKeys := []string{"service_account_json", "service_account"}
	for _, credKey := range credKeys {
		if raw := strings.TrimSpace(account.GetCredential(credKey)); raw != "" {
			return parseVertexServiceAccountJSON([]byte(raw))
		}
	}
	for _, credKey := range credKeys {
		if nested, ok := account.Credentials[credKey].(map[string]any); ok {
			b, _ := json.Marshal(nested)
			return parseVertexServiceAccountJSON(b)
		}
	}
	return nil, errors.New("service_account_json not found in credentials")
}
// parseVertexServiceAccountJSON decodes and validates a service-account key.
// client_email, private_key, and project_id are mandatory. The token endpoint
// is always forced to Google's well-known URL so a crafted token_uri cannot be
// used for SSRF.
func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
	key := &vertexServiceAccountKey{}
	if err := json.Unmarshal(raw, key); err != nil {
		return nil, fmt.Errorf("invalid service account json: %w", err)
	}
	required := []struct{ value, name string }{
		{key.ClientEmail, "client_email"},
		{key.PrivateKey, "private_key"},
		{key.ProjectID, "project_id"},
	}
	for _, field := range required {
		if strings.TrimSpace(field.value) == "" {
			return nil, errors.New("service account json missing " + field.name)
		}
	}
	key.TokenURI = vertexDefaultTokenURL
	return key, nil
}
// vertexServiceAccountCacheKey derives a stable cache key for a service
// account's access token by fingerprinting client_email + private_key_id;
// when no key material is available it falls back to the account id.
func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
	var fingerprint string
	if key != nil {
		digest := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
		fingerprint = hex.EncodeToString(digest[:8])
	}
	if fingerprint == "" && account != nil {
		fingerprint = fmt.Sprintf("account:%d", account.ID)
	}
	return "vertex:service_account:" + fingerprint
}
// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
// using the shared cache and distributed lock to avoid redundant exchanges.
//
// Flow: cache fast path → try the refresh lock → if another worker holds it,
// wait briefly and re-check the cache → otherwise exchange a fresh token and
// cache it. Lock failures are logged but never block the exchange: the lock is
// best-effort deduplication, not mutual exclusion.
func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
	key, err := parseVertexServiceAccountKey(account)
	if err != nil {
		return "", err
	}
	cacheKey := vertexServiceAccountCacheKey(account, key)
	// Fast path: a previously exchanged token is still cached.
	if cache != nil {
		if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			return token, nil
		}
	}
	locked := false
	if cache != nil {
		var lockErr error
		locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if lockErr == nil && locked {
			defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
		} else if lockErr != nil {
			// Lock infrastructure failure: warn and fall through to exchange anyway.
			slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
		} else {
			// Lock held by another worker: give it a moment, then reuse its result
			// if it landed in the cache; otherwise exchange ourselves.
			time.Sleep(vertexLockWaitTime)
			if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				return token, nil
			}
		}
	}
	accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
	if err != nil {
		return "", err
	}
	if cache != nil {
		// Best-effort cache write; a failure only costs extra exchanges later.
		_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}
	return accessToken, nil
}
// exchangeVertexServiceAccountToken performs the OAuth2 JWT-bearer grant: it
// signs a service-account assertion (RS256, 1h expiry) and exchanges it at the
// token endpoint for an access token. Returns the token plus a cache TTL
// (expires_in minus a safety skew).
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
	now := time.Now()
	claims := jwt.MapClaims{
		"iss":   key.ClientEmail,
		"scope": vertexCloudPlatformScope,
		"aud":   key.TokenURI,
		"iat":   now.Unix(),
		"exp":   now.Add(time.Hour).Unix(),
	}
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	if strings.TrimSpace(key.PrivateKeyID) != "" {
		// Propagate the key id so the endpoint can pick the matching public key.
		token.Header["kid"] = key.PrivateKeyID
	}
	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
	if err != nil {
		return "", 0, fmt.Errorf("parse service account private key: %w", err)
	}
	assertion, err := token.SignedString(privateKey)
	if err != nil {
		return "", 0, fmt.Errorf("sign service account assertion: %w", err)
	}
	values := url.Values{}
	values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
	values.Set("assertion", assertion)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode()))
	if err != nil {
		return "", 0, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", 0, fmt.Errorf("service account token request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	// Cap the response read at 1 MiB; token replies are tiny.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	var parsed vertexTokenResponse
	// Best-effort decode: on error statuses the raw body is the fallback message.
	_ = json.Unmarshal(body, &parsed)
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Prefer the structured OAuth error description, then the error code,
		// then the trimmed raw body.
		msg := strings.TrimSpace(parsed.ErrorDesc)
		if msg == "" {
			msg = strings.TrimSpace(parsed.Error)
		}
		if msg == "" {
			msg = string(bytes.TrimSpace(body))
		}
		return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
	}
	if strings.TrimSpace(parsed.AccessToken) == "" {
		return "", 0, errors.New("service account token response missing access_token")
	}
	ttl := time.Duration(parsed.ExpiresIn) * time.Second
	if ttl <= 0 {
		ttl = time.Hour // missing/invalid expires_in: assume the standard 1h
	}
	if ttl > vertexServiceAccountCacheSkew {
		// Expire the cached token early so callers never present a stale one.
		ttl -= vertexServiceAccountCacheSkew
	}
	return parsed.AccessToken, ttl, nil
}
// buildVertexGeminiURL assembles the Vertex AI endpoint URL for a Google
// publisher (Gemini) model call. Supported actions: generateContent,
// streamGenerateContent, countTokens. Streaming appends ?alt=sse. The "global"
// location uses the regionless host.
func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
	projectID = strings.TrimSpace(projectID)
	location = strings.TrimSpace(location)
	model = strings.TrimSpace(model)
	action = strings.TrimSpace(action)
	if projectID == "" {
		return "", errors.New("vertex project_id is required")
	}
	if location == "" {
		location = vertexDefaultLocation
	}
	if !vertexLocationPattern.MatchString(location) {
		return "", fmt.Errorf("invalid vertex location: %s", location)
	}
	if model == "" {
		return "", errors.New("vertex model is required")
	}
	isKnownAction := action == "generateContent" || action == "streamGenerateContent" || action == "countTokens"
	if !isKnownAction {
		return "", fmt.Errorf("unsupported vertex gemini action: %s", action)
	}
	host := "aiplatform.googleapis.com"
	if location != "global" {
		host = location + "-aiplatform.googleapis.com"
	}
	endpoint := fmt.Sprintf(
		"https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
		host,
		url.PathEscape(projectID),
		url.PathEscape(location),
		url.PathEscape(model),
		action,
	)
	if stream {
		endpoint += "?alt=sse"
	}
	return endpoint, nil
}
// buildVertexAnthropicURL assembles the Vertex AI endpoint URL for an Anthropic
// publisher model call (rawPredict, or streamRawPredict when streaming). The
// "global" location uses the regionless host; "@" in dated model ids is kept
// literal in the path.
func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
	projectID = strings.TrimSpace(projectID)
	location = strings.TrimSpace(location)
	model = strings.TrimSpace(model)
	if projectID == "" {
		return "", errors.New("vertex project_id is required")
	}
	if location == "" {
		location = vertexDefaultLocation
	}
	if !vertexLocationPattern.MatchString(location) {
		return "", fmt.Errorf("invalid vertex location: %s", location)
	}
	if model == "" {
		return "", errors.New("vertex model is required")
	}
	action := "rawPredict"
	if stream {
		action = "streamRawPredict"
	}
	host := "aiplatform.googleapis.com"
	if location != "global" {
		host = location + "-aiplatform.googleapis.com"
	}
	// url.PathEscape encodes "@" as %40; restore it so dated ids stay readable.
	escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@")
	return fmt.Sprintf(
		"https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
		host,
		url.PathEscape(projectID),
		url.PathEscape(location),
		escapedModel,
		action,
	), nil
}
// normalizeVertexAnthropicModelID converts Anthropic's dash-dated model ids
// (claude-x-20250929) into Vertex's @-dated form (claude-x@20250929). Already
// @-dated and undated ids are returned unchanged (trimmed).
func normalizeVertexAnthropicModelID(model string) string {
	trimmed := strings.TrimSpace(model)
	if trimmed == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(trimmed) {
		return trimmed
	}
	parts := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(trimmed)
	if len(parts) != 3 {
		return trimmed
	}
	return parts[1] + "@" + parts[2]
}
// buildVertexAnthropicRequestBody adapts an Anthropic Messages request body
// for Vertex: the model is carried in the URL, so it is dropped from the
// payload, and anthropic_version is pinned to the Vertex-specific value.
func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
	payload := map[string]any{}
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
	}
	delete(payload, "model")
	payload["anthropic_version"] = vertexAnthropicVersion
	return json.Marshal(payload)
}
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestBuildVertexGeminiURL checks the regional-host Gemini URL shape, including
// the ?alt=sse suffix for streaming calls.
func TestBuildVertexGeminiURL(t *testing.T) {
	got, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true)
	require.NoError(t, err)
	require.Equal(t, "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse", got)
}
// TestBuildVertexGeminiURLUsesGlobalEndpointHost checks that the "global"
// location maps to the regionless aiplatform.googleapis.com host.
func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) {
	got, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true)
	require.NoError(t, err)
	require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse", got)
}
// TestBuildVertexAnthropicURL checks the regional-host Anthropic URL shape for
// a non-streaming (rawPredict) call, with "@" kept literal in the model id.
func TestBuildVertexAnthropicURL(t *testing.T) {
	got, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false)
	require.NoError(t, err)
	require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", got)
}
// TestBuildVertexAnthropicURLUsesGlobalEndpointHost checks the regionless host
// for "global" plus the streamRawPredict action for streaming calls.
func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) {
	got, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true)
	require.NoError(t, err)
	require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict", got)
}
// TestNormalizeVertexAnthropicModelID covers dash-dated conversion, an
// already-@-dated id (unchanged), and an undated id (unchanged).
func TestNormalizeVertexAnthropicModelID(t *testing.T) {
	require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5-20250929"))
	require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5@20250929"))
	require.Equal(t, "claude-sonnet-4-6", normalizeVertexAnthropicModelID("claude-sonnet-4-6"))
}
// TestBuildVertexAnthropicRequestBody checks that the model field is dropped,
// anthropic_version is overwritten with the Vertex value, and the remaining
// payload fields survive untouched.
func TestBuildVertexAnthropicRequestBody(t *testing.T) {
	got, err := buildVertexAnthropicRequestBody([]byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`))
	require.NoError(t, err)
	require.Equal(t, "", gjson.GetBytes(got, "model").String())
	require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
	require.Equal(t, int64(64), gjson.GetBytes(got, "max_tokens").Int())
	require.Equal(t, "hi", gjson.GetBytes(got, "messages.0.content").String())
}
// TestBuildVertexGeminiURLRejectsInvalidLocation checks that a location with a
// path separator fails the location pattern rather than reaching the URL.
func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) {
	_, err := buildVertexGeminiURL("my-project", "us-central1/path", "gemini-3-pro", "generateContent", false)
	require.Error(t, err)
	require.Contains(t, err.Error(), "invalid vertex location")
}
// TestParseVertexServiceAccountKey parses a string-valued service_account_json
// credential and verifies the extracted fields, including that token_uri is
// forced to the well-known Google endpoint.
func TestParseVertexServiceAccountKey(t *testing.T) {
	raw := `{
		"type": "service_account",
		"project_id": "vertex-proj",
		"private_key_id": "kid",
		"private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
		"client_email": "svc@vertex-proj.iam.gserviceaccount.com"
	}`
	account := &Account{
		Type:     AccountTypeServiceAccount,
		Platform: PlatformGemini,
		Credentials: map[string]any{
			"service_account_json": raw,
		},
	}
	key, err := parseVertexServiceAccountKey(account)
	require.NoError(t, err)
	require.Equal(t, "vertex-proj", key.ProjectID)
	require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", key.ClientEmail)
	require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
	require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
}
......@@ -404,12 +404,28 @@ func ProvideBillingCacheService(
return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
}
// ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation.
func ProvideAPIKeyService(
	apiKeyRepo APIKeyRepository,
	userRepo UserRepository,
	groupRepo GroupRepository,
	userSubRepo UserSubscriptionRepository,
	userGroupRateRepo UserGroupRateRepository,
	cache APIKeyCache,
	cfg *config.Config,
	billingCacheService *BillingCacheService,
) *APIKeyService {
	svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg)
	// Hook up the invalidator after construction; BillingCacheService serves
	// as the rate-limit cache invalidator here.
	svc.SetRateLimitCacheInvalidator(billingCacheService)
	return svc
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewAPIKeyService,
ProvideAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
NewGroupService,
NewAccountService,
......
-- payment_orders: one row per payment attempt (balance top-up by default,
-- or a subscription purchase). Tracks amounts, payment-provider handshake
-- artifacts, lifecycle status, and refund bookkeeping.
CREATE TABLE IF NOT EXISTS payment_orders (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
-- Denormalized user identity, presumably snapshotted at order-creation
-- time so the order remains readable if the user record changes.
user_email VARCHAR(255) NOT NULL DEFAULT '',
user_name VARCHAR(100) NOT NULL DEFAULT '',
user_notes TEXT,
-- amount vs pay_amount: both stored; pay_amount looks like the charged
-- figure after fee_rate is applied — confirm against order-creation code.
amount DECIMAL(20,2) NOT NULL,
pay_amount DECIMAL(20,2) NOT NULL,
fee_rate DECIMAL(10,4) NOT NULL DEFAULT 0,
recharge_code VARCHAR(64) NOT NULL DEFAULT '',
-- Upstream payment-provider fields (provider name, its trade number, and
-- the URL / QR-code payloads shown to the user).
payment_type VARCHAR(30) NOT NULL DEFAULT '',
payment_trade_no VARCHAR(128) NOT NULL DEFAULT '',
pay_url TEXT,
qr_code TEXT,
qr_code_img TEXT,
-- Subscription orders carry plan/group/days; balance orders leave them NULL.
order_type VARCHAR(20) NOT NULL DEFAULT 'balance',
plan_id BIGINT,
subscription_group_id BIGINT,
subscription_days INT,
provider_instance_id VARCHAR(64),
status VARCHAR(30) NOT NULL DEFAULT 'PENDING',
-- Refund bookkeeping: the executed refund (amount/reason/at/force flag)
-- plus the original refund request metadata.
refund_amount DECIMAL(20,2) NOT NULL DEFAULT 0,
refund_reason TEXT,
refund_at TIMESTAMPTZ,
force_refund BOOLEAN NOT NULL DEFAULT FALSE,
refund_requested_at TIMESTAMPTZ,
refund_request_reason TEXT,
refund_requested_by VARCHAR(20),
-- Lifecycle timestamps; expires_at is mandatory (every order expires).
expires_at TIMESTAMPTZ NOT NULL,
paid_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ,
failed_at TIMESTAMPTZ,
failed_reason TEXT,
-- Request provenance captured when the order was created.
client_ip VARCHAR(50) NOT NULL DEFAULT '',
src_host VARCHAR(255) NOT NULL DEFAULT '',
src_url TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_payment_orders_user_id ON payment_orders(user_id);
CREATE INDEX IF NOT EXISTS idx_payment_orders_status ON payment_orders(status);
CREATE INDEX IF NOT EXISTS idx_payment_orders_expires_at ON payment_orders(expires_at);
CREATE INDEX IF NOT EXISTS idx_payment_orders_created_at ON payment_orders(created_at);
CREATE INDEX IF NOT EXISTS idx_payment_orders_paid_at ON payment_orders(paid_at);
-- Composite index: supports queries filtering by provider and paid time.
CREATE INDEX IF NOT EXISTS idx_payment_orders_type_paid ON payment_orders(payment_type, paid_at);
CREATE INDEX IF NOT EXISTS idx_payment_orders_order_type ON payment_orders(order_type);
-- payment_audit_logs: append-only audit trail of actions taken on a payment
-- order. order_id references payment orders by their string identifier (no
-- foreign key is declared here).
CREATE TABLE IF NOT EXISTS payment_audit_logs (
id BIGSERIAL PRIMARY KEY,
order_id VARCHAR(64) NOT NULL,
action VARCHAR(50) NOT NULL,
detail TEXT NOT NULL DEFAULT '',
-- Who performed the action; defaults to automated/system activity.
operator VARCHAR(100) NOT NULL DEFAULT 'system',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_payment_audit_logs_order_id ON payment_audit_logs(order_id);
-- Migration 092: payment_channels table was removed before release.
-- This file is a no-op placeholder to maintain migration numbering continuity.
-- The payment system now uses the existing channels table (migration 081).
-- SELECT 1 gives the migration runner a valid statement to execute.
SELECT 1;
-- Add a free-form features column to channels for payment-page display.
ALTER TABLE channels ADD COLUMN IF NOT EXISTS features TEXT NOT NULL DEFAULT '';
-- Column comment (Chinese): "channel feature descriptions, JSON array
-- format, used for display on the payment page".
COMMENT ON COLUMN channels.features IS '渠道特性描述,JSON 数组格式,用于支付页面展示';
-- subscription_plans: purchasable plans, each belonging to a subscription
-- group. Pricing, validity window, and storefront presentation fields.
CREATE TABLE IF NOT EXISTS subscription_plans (
id BIGSERIAL PRIMARY KEY,
group_id BIGINT NOT NULL,
name VARCHAR(100) NOT NULL,
description TEXT NOT NULL DEFAULT '',
-- original_price is nullable; presumably shown struck-through when a
-- discounted price applies — confirm against the storefront UI.
price DECIMAL(20,2) NOT NULL,
original_price DECIMAL(20,2),
-- Validity defaults to 30 days; validity_unit allows other units.
validity_days INT NOT NULL DEFAULT 30,
validity_unit VARCHAR(10) NOT NULL DEFAULT 'day',
features TEXT NOT NULL DEFAULT '',
product_name VARCHAR(100) NOT NULL DEFAULT '',
-- for_sale toggles storefront visibility; sort_order controls listing order.
for_sale BOOLEAN NOT NULL DEFAULT TRUE,
sort_order INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_subscription_plans_group_id ON subscription_plans(group_id);
CREATE INDEX IF NOT EXISTS idx_subscription_plans_for_sale ON subscription_plans(for_sale);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment