Commit 538ae31a authored by 陈曦's avatar 陈曦
Browse files

Merge v0.1.121 and fix conflicts

parents 74828a7c 48912014
Pipeline #82338 passed with stage
in 17 seconds
package service
import (
"context"
"errors"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// gatewayTTLSettingRepo is an in-memory setting-repository stub used by the
// gateway cache-TTL tests. A nil receiver behaves as an empty, read-only store.
type gatewayTTLSettingRepo struct {
	data map[string]string
}
// Get always reports the setting as missing; the cache-TTL tests only exercise
// the value-based accessors (GetValue / GetMultiple).
func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) {
	return nil, ErrSettingNotFound
}
// GetValue returns the stored value for key, or ErrSettingNotFound when the
// stub is nil or the key is absent.
func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) {
	if r != nil {
		if value, present := r.data[key]; present {
			return value, nil
		}
	}
	return "", ErrSettingNotFound
}
// Set stores a single key/value pair, lazily allocating the backing map.
// A nil receiver is reported as an error rather than panicking.
func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value string) error {
	if r == nil {
		return errors.New("setting repo is nil")
	}
	if r.data == nil {
		r.data = map[string]string{key: value}
		return nil
	}
	r.data[key] = value
	return nil
}
// GetMultiple returns the subset of keys present in the stub; missing keys are
// silently omitted, mirroring the real repository contract.
func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
	found := map[string]string{}
	if r == nil {
		return found, nil
	}
	for _, k := range keys {
		if value, present := r.data[k]; present {
			found[k] = value
		}
	}
	return found, nil
}
// SetMultiple stores every key/value pair, lazily allocating the backing map.
func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
	if r == nil {
		return errors.New("setting repo is nil")
	}
	if r.data == nil {
		r.data = make(map[string]string, len(settings))
	}
	for k, v := range settings {
		r.data[k] = v
	}
	return nil
}
// GetAll returns a defensive copy of the stored settings; a nil stub yields
// an empty (non-nil) map.
func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) {
	if r == nil {
		return map[string]string{}, nil
	}
	snapshot := make(map[string]string, len(r.data))
	for k, v := range r.data {
		snapshot[k] = v
	}
	return snapshot, nil
}
// Delete removes key if present; deleting from a nil stub is a no-op.
func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error {
	if r == nil {
		return nil
	}
	delete(r.data, key)
	return nil
}
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
t.Helper()
......@@ -71,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
}
// TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl
// verifies that injectAnthropicCacheControlTTL1h sets ttl="1h" only on
// cache_control objects whose type is "ephemeral" (top level, system blocks,
// message content, tools), leaves non-ephemeral cache_control untouched, does
// not add cache_control where none existed, and preserves top-level field order.
func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
	body := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
	result := injectAnthropicCacheControlTTL1h(body)
	resultStr := string(result)
	// Top-level JSON field order must survive the in-place rewrite.
	assertJSONTokenOrder(t, resultStr, `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`)
	// Ephemeral entries gain (or have overwritten) ttl="1h".
	require.Equal(t, "1h", gjson.GetBytes(result, "cache_control.ttl").String())
	require.Equal(t, "1h", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
	// Blocks without cache_control must not gain one.
	require.False(t, gjson.GetBytes(result, "system.1.cache_control").Exists())
	require.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
	// Non-ephemeral ("persistent") cache_control keeps its original ttl.
	require.Equal(t, "5m", gjson.GetBytes(result, "messages.0.content.1.cache_control.ttl").String())
	require.Equal(t, "1h", gjson.GetBytes(result, "tools.0.cache_control.ttl").String())
}
// TestGatewayCacheTTLGlobalSetting_TargetResolution verifies that with the
// global 1h-injection setting enabled, the usage-override target defaults to
// 5m, and that a per-account override stored in Extra switches it to 1h.
func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) {
	repo := &gatewayTTLSettingRepo{data: map[string]string{
		SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
	}}
	// Reset the process-wide forwarding-settings cache so the repo value is re-read.
	gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
	svc := &GatewayService{
		settingService: NewSettingService(repo, &config.Config{}),
	}
	account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}
	// Without an account-level override, the target resolves to 5m.
	target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
	require.True(t, ok)
	require.Equal(t, cacheTTLTarget5m, target)
	// An explicit per-account override in Extra flips the target to 1h.
	account.Extra = map[string]any{
		"cache_ttl_override_enabled": true,
		"cache_ttl_override_target":  "1h",
	}
	target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
	require.True(t, ok)
	require.Equal(t, cacheTTLTarget1h, target)
}
// TestGatewayCacheTTLGlobalSetting_RequestInjectionScope verifies that the
// 1h TTL injection only applies to Anthropic OAuth/setup-token accounts, is
// skipped for API-key and non-Anthropic accounts, and is disabled entirely
// when the global setting is turned off.
func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) {
	repo := &gatewayTTLSettingRepo{data: map[string]string{
		SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
	}}
	// Reset the forwarding-settings cache so the enabled flag is re-read.
	gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
	svc := &GatewayService{
		settingService: NewSettingService(repo, &config.Config{}),
	}
	require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
	require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken}))
	require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey}))
	require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}))
	// Turning the global flag off (and resetting the cache) disables injection.
	repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false"
	gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
	require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
}
......@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
if normalized != originalModel {
mappedModel = normalized
}
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
......@@ -315,6 +320,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
}
// Marshal then bytes-replace so tool name mapping is reversed at byte level
// (parity with Parrot non-stream flow that marshals → restore → emit).
if respBytes, err := json.Marshal(ccResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, ccResp)
}
// Marshal then bytes-replace so tool name mapping is reversed at byte level
// (parity with Parrot non-stream flow that marshals → restore → emit).
var responseBody string
if respBytes, err := json.Marshal(ccResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
......
......@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
if normalized != originalModel {
mappedModel = normalized
}
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
......
This diff is collapsed.
......@@ -4,9 +4,12 @@ package service
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"syscall"
"testing"
"time"
......@@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
body := rec.Body.String()
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
}
// An upstream mid-stream read error (e.g. an unexpected EOF triggered by an
// HTTP/2 GOAWAY) occurring before any bytes were written to the client:
// the gateway must return *UpstreamFailoverError so the handler can fail over
// to another account / retry, instead of emitting the error event to the client.
func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := newMinimalGatewayService()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	// Body fails on the very first Read, so nothing reaches the client.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       &streamReadCloser{err: io.ErrUnexpectedEOF},
	}
	result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
	require.Error(t, err)
	require.Nil(t, result, "失败移交场景下不应返回 streamingResult")
	var failoverErr *UpstreamFailoverError
	require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试")
	// ResponseBody must be in the standard Anthropic error format:
	// 1) ExtractUpstreamErrorMessage can pull the message out of error.message
	//    (relied on by handleFailoverExhausted / ops logging)
	// 2) error.type is marked upstream_disconnected
	extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
	require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息")
	require.Contains(t, extractedMsg, "upstream stream disconnected")
	require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
	require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
	// The client must not receive any stream_read_error event; the handler
	// layer decides what to send based on the failover outcome.
	require.NotContains(t, rec.Body.String(), "stream_read_error")
}
// After the upstream has already emitted events (bytes were written to
// c.Writer), a later read error cannot fail over: SSE has no resume, so the
// gateway can only pass a stream_read_error event through to the client.
func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := newMinimalGatewayService()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	// The first Read returns a complete SSE event so the gateway writes bytes
	// to the client; the second Read fails with unexpected EOF.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body: &streamReadCloser{
			payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
			err:     io.ErrUnexpectedEOF,
		},
	}
	result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
	require.Error(t, err)
	require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error")
	require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult")
	// The error must NOT be wrongly wrapped as a failover error.
	var failoverErr *UpstreamFailoverError
	require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover")
	// The client must receive an Anthropic-standard SSE error event with
	// error.type=stream_read_error and error.message carrying the concrete
	// root cause (so SDKs can parse it and UIs can show a useful message).
	body := rec.Body.String()
	require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧")
	require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)")
	require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error")
	require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案")
}
// By default (*net.OpError).Error() concatenates the Source/Addr fields,
// leaking internal IP/port pairs and the upstream server address.
// sanitizeStreamError must strip this information so infrastructure topology
// cannot reach clients via the failover ResponseBody or SSE error frames.
func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
	src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
	require.NoError(t, err)
	dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
	require.NoError(t, err)
	raw := &net.OpError{
		Op:     "read",
		Net:    "tcp",
		Source: src,
		Addr:   dst,
		Err:    syscall.ECONNRESET,
	}
	// Precondition: the raw Error() really does contain the leaking fields
	// (so the test does not pass silently if Go's formatting ever changes).
	require.Contains(t, raw.Error(), "10.0.0.1")
	require.Contains(t, raw.Error(), "52.1.2.3")
	got := sanitizeStreamError(raw)
	require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP")
	require.NotContains(t, got, "54321", "不得泄露源端口")
	require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP")
	require.NotContains(t, got, "443", "不得泄露上游端口")
	require.Equal(t, "connection reset by peer", got)
}
// TestSanitizeStreamError_KnownErrors pins the sanitized message for each
// recognized error class, the generic fallback for unknown errors, and the
// empty string for nil.
func TestSanitizeStreamError_KnownErrors(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want string
	}{
		{"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
		{"EOF", io.EOF, "EOF"},
		{"context canceled", context.Canceled, "canceled"},
		{"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
		{"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"},
		{"EPIPE", syscall.EPIPE, "broken pipe"},
		{"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
		{"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"},
		{"nil 返回空串", nil, ""},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, sanitizeStreamError(tc.err))
		})
	}
}
// The failover ResponseBody must use the sanitized message so that neither
// clients nor ops logs receive internal address information.
func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := newMinimalGatewayService()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
	dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
	netErr := &net.OpError{
		Op:     "read",
		Net:    "tcp",
		Source: src,
		Addr:   dst,
		Err:    syscall.ECONNRESET,
	}
	// Fail the very first Read so the error is wrapped as a failover error.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       &streamReadCloser{err: netErr},
	}
	_, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.True(t, errors.As(err, &failoverErr))
	body := string(failoverErr.ResponseBody)
	require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP")
	require.NotContains(t, body, "54321")
	require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP")
	require.NotContains(t, body, "443")
	// The diagnosable root cause must still be present.
	require.Contains(t, body, "connection reset by peer")
	require.Contains(t, body, "upstream stream disconnected")
}
......@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
case AccountTypeServiceAccount:
// Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
// endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
return 999
default:
return 10
}
......@@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(req.Model)
}
......@@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = "x-request-id"
case AccountTypeServiceAccount:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
if err != nil {
return nil, "", err
}
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
......@@ -1101,7 +1135,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
......@@ -1220,6 +1254,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = "x-request-id"
case AccountTypeServiceAccount:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
if err != nil {
return nil, "", err
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
......
......@@ -15,7 +15,7 @@ const (
geminiTokenCacheSkew = 5 * time.Minute
)
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
......@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return "", errors.New("not a gemini oauth account")
if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
return "", errors.New("not a gemini oauth or service account")
}
if account.Type == AccountTypeServiceAccount {
return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := GeminiTokenCacheKey(account)
......@@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
}
func GeminiTokenCacheKey(account *Account) string {
if account != nil && account.Type == AccountTypeServiceAccount {
if key, err := parseVertexServiceAccountKey(account); err == nil {
return vertexServiceAccountCacheKey(account, key)
}
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "gemini:" + projectID
......
......@@ -53,6 +53,23 @@ const (
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
)
// openAIChatGPTInternalUnsupportedFields lists request fields the ChatGPT
// internal Codex endpoint rejects regardless of account type; they are
// stripped from the body before forwarding.
var openAIChatGPTInternalUnsupportedFields = []string{
	"user",
	"metadata",
	// prompt_cache_retention is a newer Responses API cache-TTL parameter that
	// the internal endpoint rejects ("Unsupported parameter: prompt_cache_retention").
	"prompt_cache_retention",
	"safety_identifier",
	"stream_options",
}
// openAICodexOAuthUnsupportedFields extends the internal-endpoint list with
// sampling/limit parameters that codex models reject via the Responses API on
// the OAuth path.
var openAICodexOAuthUnsupportedFields = append([]string{
	"max_output_tokens",
	"max_completion_tokens",
	"temperature",
	"top_p",
	"frequency_penalty",
	"presence_penalty",
}, openAIChatGPTInternalUnsupportedFields...)
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
......@@ -93,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
}
}
// Strip parameters unsupported by codex models via the Responses API.
for _, key := range []string{
"max_output_tokens",
"max_completion_tokens",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
// prompt_cache_retention is a newer Responses API parameter (cache TTL).
// The ChatGPT internal Codex endpoint rejects it with
// "Unsupported parameter: prompt_cache_retention". Defense-in-depth
// for any OAuth path that reaches this transform — the Cursor
// Responses-shape short-circuit in ForwardAsChatCompletions strips
// it earlier too, but we keep this line so other OAuth callers are
// equally protected.
"prompt_cache_retention",
} {
// Strip parameters unsupported by ChatGPT internal Codex endpoint.
for _, key := range openAICodexOAuthUnsupportedFields {
if _, ok := reqBody[key]; ok {
delete(reqBody, key)
result.Modified = true
......@@ -141,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
reqBody["tool_choice"] = map[string]any{
"type": "function",
"function": map[string]any{
"name": name,
},
"name": name,
}
}
}
......@@ -219,9 +219,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool {
return false
}
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
if choiceType == "" {
return false
}
modified := false
if choiceType == "function" {
name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
if name == "" {
if function, ok := choiceMap["function"].(map[string]any); ok {
name = strings.TrimSpace(firstNonEmptyString(function["name"]))
}
}
if name == "" {
reqBody["tool_choice"] = "auto"
return true
}
if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
choiceMap["name"] = name
modified = true
}
if _, ok := choiceMap["function"]; ok {
delete(choiceMap, "function")
modified = true
}
if !codexToolsContainFunctionName(reqBody["tools"], name) {
reqBody["tool_choice"] = "auto"
return true
}
return modified
}
if codexToolsContainType(reqBody["tools"], choiceType) {
return modified
}
reqBody["tool_choice"] = "auto"
return true
}
......@@ -243,6 +272,33 @@ func codexToolsContainType(rawTools any, toolType string) bool {
return false
}
// codexToolsContainFunctionName reports whether rawTools (expected to be a
// []any of tool maps) declares a function tool whose name matches the given
// name after whitespace trimming. It accepts both the flat Responses shape
// ({"type":"function","name":...}) and the legacy nested Chat Completions
// shape ({"type":"function","function":{"name":...}}).
func codexToolsContainFunctionName(rawTools any, name string) bool {
	want := strings.TrimSpace(name)
	tools, ok := rawTools.([]any)
	if !ok || want == "" {
		return false
	}
	for _, raw := range tools {
		tool, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		// Only function-typed tools are candidates.
		if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
			continue
		}
		candidate := strings.TrimSpace(firstNonEmptyString(tool["name"]))
		if candidate == "" {
			// Fall back to the legacy nested "function" object.
			if fn, nested := tool["function"].(map[string]any); nested {
				candidate = strings.TrimSpace(firstNonEmptyString(fn["name"]))
			}
		}
		if candidate == want {
			return true
		}
	}
	return false
}
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
if len(input) == 0 {
return input, false
......@@ -853,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
typ, _ := m["type"].(string)
// chatgpt.com codex backend (OAuth path) does not persist reasoning
// items because applyCodexOAuthTransform forces store=false. Any rs_*
// reference replayed in input is guaranteed to 404 upstream
// ("Item with id 'rs_...' not found"). Drop reasoning items entirely.
if typ == "reasoning" {
continue
}
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
fixCallIDPrefix := func(id string) string {
......
package service
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/require"
......@@ -249,6 +251,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
require.Equal(t, "custom", choice["type"])
}
// The legacy Chat Completions shape {"type":"function","function":{"name":...}}
// must be flattened to the Responses shape {"type":"function","name":...}
// when the named tool exists.
func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) {
	body := map[string]any{
		"model": "gpt-5.4",
		"tool_choice": map[string]any{
			"type":     "function",
			"function": map[string]any{"name": "shell"},
		},
		"tools": []any{
			map[string]any{"type": "function", "name": "shell"},
		},
	}
	applyCodexOAuthTransform(body, true, false)
	normalized, isMap := body["tool_choice"].(map[string]any)
	require.True(t, isMap)
	require.Equal(t, "function", normalized["type"])
	require.Equal(t, "shell", normalized["name"])
	require.NotContains(t, normalized, "function")
}
// A function tool_choice naming a tool that is not declared in tools must be
// downgraded to the safe "auto" choice.
func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) {
	body := map[string]any{
		"model": "gpt-5.4",
		"tool_choice": map[string]any{
			"type":     "function",
			"function": map[string]any{"name": "missing"},
		},
		"tools": []any{
			map[string]any{"type": "function", "name": "shell"},
		},
	}
	applyCodexOAuthTransform(body, true, false)
	require.Equal(t, "auto", body["tool_choice"])
}
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
......@@ -1048,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) {
"prompt_cache_retention must be stripped before forwarding to Codex upstream")
}
// Every field in openAIChatGPTInternalUnsupportedFields must be removed from
// the request body, and the transform must report that it modified the body.
func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) {
	body := map[string]any{
		"model":                  "gpt-5.4",
		"input":                  []any{map[string]any{"role": "user", "content": "hi"}},
		"user":                   "user_123",
		"metadata":               map[string]any{"trace_id": "abc"},
		"prompt_cache_retention": "24h",
		"safety_identifier":      "sid",
		"stream_options":         map[string]any{"include_usage": true},
	}
	outcome := applyCodexOAuthTransform(body, true, false)
	require.True(t, outcome.Modified)
	for _, field := range openAIChatGPTInternalUnsupportedFields {
		require.NotContains(t, body, field)
	}
}
func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -1094,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) {
})
}
}
// TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences
// verifies that filterCodexInput removes every reasoning item (and thus every
// rs_* id) from the input while keeping all other item types intact.
func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) {
	// Reasoning items in input[] reference rs_* IDs that were emitted by
	// chatgpt.com under store=false (forced by applyCodexOAuthTransform).
	// They are never persisted upstream, so forwarding them produces a
	// guaranteed 404 ("Item with id 'rs_...' not found"). Drop them
	// regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957.
	build := func() []any {
		return []any{
			map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
			map[string]any{
				"type":    "reasoning",
				"id":      "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344",
				"summary": []any{},
			},
			map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"},
			map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"},
		}
	}
	for _, preserve := range []bool{true, false} {
		preserve := preserve // shadow for the closure (pre-Go 1.22 loop-var semantics)
		t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) {
			filtered := filterCodexInput(build(), preserve)
			for _, raw := range filtered {
				item, ok := raw.(map[string]any)
				require.True(t, ok)
				require.NotEqual(t, "reasoning", item["type"],
					"reasoning items must be dropped from input on the OAuth path")
				if id, ok := item["id"].(string); ok {
					require.False(t, strings.HasPrefix(id, "rs_"),
						"no item carrying an rs_* id should survive the filter")
				}
			}
			// Sanity check: the non-reasoning items should still be present.
			gotTypes := make(map[string]int)
			for _, raw := range filtered {
				item, ok := raw.(map[string]any)
				require.True(t, ok)
				typ, ok := item["type"].(string)
				require.True(t, ok)
				gotTypes[typ]++
			}
			require.Equal(t, 1, gotTypes["message"])
			require.Equal(t, 1, gotTypes["function_call"])
			require.Equal(t, 1, gotTypes["function_call_output"])
			require.Equal(t, 0, gotTypes["reasoning"])
		})
	}
}
package service
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// openAIFastPolicyRepoStub is a minimal setting-repository stub for the
// fast-policy tests; any method the code under test must not call panics.
type openAIFastPolicyRepoStub struct {
	values map[string]string
}
// Get is not used by the fast-policy code path; panic to surface accidental calls.
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
	panic("unexpected Get call")
}
// GetValue returns the stored value for key, or ErrSettingNotFound.
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	value, ok := s.values[key]
	if !ok {
		return "", ErrSettingNotFound
	}
	return value, nil
}
// Set stores key/value, lazily allocating the backing map on first write.
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
	if s.values == nil {
		s.values = map[string]string{key: value}
		return nil
	}
	s.values[key] = value
	return nil
}
// GetMultiple is not used by the fast-policy code path; panic on accidental calls.
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	panic("unexpected GetMultiple call")
}
// SetMultiple is not used by the fast-policy code path; panic on accidental calls.
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
	panic("unexpected SetMultiple call")
}
// GetAll is not used by the fast-policy code path; panic on accidental calls.
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}
// Delete is not used by the fast-policy code path; panic on accidental calls.
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
	panic("unexpected Delete call")
}
// newOpenAIGatewayServiceWithSettings builds an OpenAIGatewayService whose
// setting service is backed by the in-memory stub, optionally pre-seeded with
// the JSON-encoded fast-policy settings. A nil settings leaves the store empty.
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
	t.Helper()
	repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
	if settings != nil {
		raw, err := json.Marshal(settings)
		require.NoError(t, err)
		repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
	}
	return &OpenAIGatewayService{
		settingService: NewSettingService(repo, &config.Config{}),
	}
}
// TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority verifies the
// default policy filters the priority tier across all models and passes
// non-matching or empty tiers through.
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	// The default policy applies to every model (the whitelist is empty),
	// because codex's service_tier=fast is a user-level toggle that is
	// orthogonal to the model.
	// gpt-5.5 + priority → filter
	action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)
	// gpt-5.5-turbo → filter
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)
	// gpt-4 + priority → filter (the default policy covers all models)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)
	// gpt-5.5 + flex → pass (tier doesn't match)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
	require.Equal(t, BetaPolicyActionPass, action)
	// empty tier → pass
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
	require.Equal(t, BetaPolicyActionPass, action)
}
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is not allowed",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionBlock, action)
require.Equal(t, "fast mode is not allowed", msg)
}
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeOAuth,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
// OAuth account → rule matches
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// API Key account → rule skipped → pass
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionPass, action)
}
// TestApplyOpenAIFastPolicyToBody_FilterRemovesField verifies that the
// default policy strips service_tier from request bodies and leaves bodies
// without the field untouched.
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
	ctx := context.Background()
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	acct := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// gpt-5.5 with priority → service_tier stripped.
	got, err := svc.applyOpenAIFastPolicyToBody(ctx, acct, "gpt-5.5", []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`))
	require.NoError(t, err)
	require.NotContains(t, string(got), `"service_tier"`)

	// A client sending "fast" (alias for priority) is filtered as well.
	got, err = svc.applyOpenAIFastPolicyToBody(ctx, acct, "gpt-5.5", []byte(`{"model":"gpt-5.5","service_tier":"fast"}`))
	require.NoError(t, err)
	require.NotContains(t, string(got), `"service_tier"`)

	// The default policy covers every model, so gpt-4 is filtered too.
	got, err = svc.applyOpenAIFastPolicyToBody(ctx, acct, "gpt-4", []byte(`{"model":"gpt-4","service_tier":"priority"}`))
	require.NoError(t, err)
	require.NotContains(t, string(got), `"service_tier"`)

	// Without service_tier the body passes through unchanged.
	original := `{"model":"gpt-5.5"}`
	got, err = svc.applyOpenAIFastPolicyToBody(ctx, acct, "gpt-5.5", []byte(original))
	require.NoError(t, err)
	require.Equal(t, original, string(got))
}
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies
// that, after widening the whitelist, OpenAI's official tiers
// (auto/default/scale) explicitly sent by clients are forwarded upstream
// rather than silently stripped. The default rule only targets priority, so
// these tiers land on the fall-through pass branch — at both the body level
// and the evaluate level.
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
	ctx := context.Background()
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	acct := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	for _, tier := range []string{"auto", "default", "scale"} {
		payload := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		got, err := svc.applyOpenAIFastPolicyToBody(ctx, acct, "gpt-5.5", payload)
		require.NoError(t, err, "tier %q should pass without error", tier)
		require.Contains(t, string(got), `"service_tier":"`+tier+`"`,
			"tier %q should be preserved in body under default rule", tier)

		// The evaluate layer must also report pass: the default rule's
		// ServiceTier=priority does not match auto/default/scale.
		action, _ := svc.evaluateOpenAIFastPolicy(ctx, acct, "gpt-5.5", tier)
		require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
	}
}
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会
// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`,
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
}
}
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped verifies handling of a
// genuinely unknown tier: the normalize step maps it to nil (so the upstream
// normalizeResponsesBodyServiceTier would delete the field), while
// applyOpenAIFastPolicyToBody itself is a no-op for an unrecognized value —
// after upstream normalization the field can no longer be present. Calling
// apply directly here just proves it does not error on unknown values.
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	acct := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The normalize stage rejects unknown values.
	require.Nil(t, normalizeOpenAIServiceTier("xxx"))

	// apply leaves the body untouched: stripping unknowns is not its job.
	payload := `{"model":"gpt-5.5","service_tier":"xxx"}`
	got, err := svc.applyOpenAIFastPolicyToBody(context.Background(), acct, "gpt-5.5", []byte(payload))
	require.NoError(t, err)
	require.Equal(t, payload, string(got))
}
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is blocked for gpt-5.5",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.Error(t, err)
var blocked *OpenAIFastBlockedError
require.True(t, errors.As(err, &blocked))
require.Contains(t, blocked.Message, "fast mode is blocked")
require.Equal(t, string(body), string(updated)) // body not mutated on block
}
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, &config.Config{})
// Invalid action rejected
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: "bogus",
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Invalid service_tier rejected
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: "turbo",
Action: BetaPolicyActionPass,
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Valid settings persisted
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
})
require.NoError(t, err)
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
require.NoError(t, err)
require.Len(t, got.Rules, 1)
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}
This diff is collapsed.
......@@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
}
// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
......
......@@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "flex", req.ServiceTier)
// OpenAI 官方合法 tier 应被透传保留。
req.ServiceTier = "auto"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "auto", req.ServiceTier)
req.ServiceTier = "default"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "default", req.ServiceTier)
req.ServiceTier = "scale"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "scale", req.ServiceTier)
// 真未知值仍被剥离。
req.ServiceTier = "turbo"
normalizeResponsesRequestServiceTier(req)
require.Empty(t, req.ServiceTier)
}
......@@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Equal(t, "flex", tier)
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
// OpenAI 官方 tier 直接保留在 body 中(透传上游)。
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
require.NoError(t, err)
require.Equal(t, "auto", tier)
require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
require.NoError(t, err)
require.Equal(t, "default", tier)
require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
require.NoError(t, err)
require.Equal(t, "scale", tier)
require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
// 真未知值才会被删除。
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
require.NoError(t, err)
require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}
......@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
// on the body-level service_tier field (priority/flex).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
......
......@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
......@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *got)
})
t.Run("default ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("default"))
t.Run("openai official tiers preserved", func(t *testing.T) {
// OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
// 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex,
// 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。
for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
got := normalizeOpenAIServiceTier(tier)
require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
require.Equal(t, tier, *got)
}
})
t.Run("invalid ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
})
}
func TestExtractOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
require.Nil(t, extractOpenAIServiceTier(nil))
}
......@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
// TestExtractOpenAIServiceTierFromBody verifies tier extraction from a raw
// JSON body: "fast" maps to "priority", official tiers (flex/auto/default/
// scale) are preserved, and an unknown value or a nil body yields nil.
//
// NOTE(review): a stale pre-merge assertion expecting "default" to yield nil
// contradicted the "default is preserved" assertion below (the test could
// never pass with both); the stale line has been removed.
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
	require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
	require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
	require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
	require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
	require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
	require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
	require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}
......
......@@ -1767,6 +1767,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
}
}
// TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields
// verifies that compact-request normalization keeps the fields the current
// Codex payload carries (model, tools, parallel_tool_calls, reasoning, text,
// previous_response_id) while stripping store, stream and prompt_cache_key.
func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) {
	payload := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`)
	got, changed, err := normalizeOpenAICompactRequestBody(payload)
	require.NoError(t, err)
	require.True(t, changed)

	// Fields the client relies on survive normalization.
	require.Equal(t, "gpt-5.5", gjson.GetBytes(got, "model").String())
	require.True(t, gjson.GetBytes(got, "tools").Exists())
	require.True(t, gjson.GetBytes(got, "parallel_tool_calls").Bool())
	require.Equal(t, "high", gjson.GetBytes(got, "reasoning.effort").String())
	require.Equal(t, "low", gjson.GetBytes(got, "text.verbosity").String())
	require.Equal(t, "resp_123", gjson.GetBytes(got, "previous_response_id").String())

	// Unsupported fields are removed.
	for _, field := range []string{"store", "stream", "prompt_cache_key"} {
		require.False(t, gjson.GetBytes(got, field).Exists(), "%s should be stripped", field)
	}
}
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
......
package service
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser verifies the
// non-compact passthrough path strips every field listed in
// openAIChatGPTInternalUnsupportedFields and leaves stream=true / store=false
// in the normalized body.
func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) {
	payload := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`)
	got, changed, err := normalizeOpenAIPassthroughOAuthBody(payload, false)
	require.NoError(t, err)
	require.True(t, changed)
	for _, field := range openAIChatGPTInternalUnsupportedFields {
		require.False(t, gjson.GetBytes(got, field).Exists(), "%s should be stripped", field)
	}
	require.True(t, gjson.GetBytes(got, "stream").Bool())
	require.False(t, gjson.GetBytes(got, "store").Bool())
}
// TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser
// verifies the compact path strips user and metadata and drops the stream
// and store fields entirely from the normalized body.
func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) {
	payload := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`)
	got, changed, err := normalizeOpenAIPassthroughOAuthBody(payload, true)
	require.NoError(t, err)
	require.True(t, changed)
	for _, field := range []string{"user", "metadata", "stream", "store"} {
		require.False(t, gjson.GetBytes(got, field).Exists(), "%s should be absent", field)
	}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment