Unverified Commit 6bccb8a8 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
package service
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/gin-gonic/gin"
)
const (
	// CodexClientRestrictionReasonDisabled means the account has not enabled codex_cli_only.
	CodexClientRestrictionReasonDisabled = "codex_cli_only_disabled"
	// CodexClientRestrictionReasonMatchedUA means the request matched the official client User-Agent whitelist.
	CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched"
	// CodexClientRestrictionReasonMatchedOriginator means the request matched the official client originator whitelist.
	CodexClientRestrictionReasonMatchedOriginator = "official_client_originator_matched"
	// CodexClientRestrictionReasonNotMatchedUA means the request matched neither official-client whitelist.
	CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched"
	// CodexClientRestrictionReasonForceCodexCLI means the request was let through by the ForceCodexCLI config override.
	CodexClientRestrictionReasonForceCodexCLI = "force_codex_cli_enabled"
)
// CodexClientRestrictionDetectionResult 是 codex_cli_only 统一检测入口结果。
// CodexClientRestrictionDetectionResult is the outcome of the unified
// codex_cli_only detection entry point.
type CodexClientRestrictionDetectionResult struct {
	Enabled bool   // whether the account has the codex_cli_only restriction enabled
	Matched bool   // whether the request was recognized as an official Codex client (or force-allowed)
	Reason  string // one of the CodexClientRestrictionReason* constants
}
// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。
// CodexClientRestrictionDetector defines the unified codex_cli_only detection
// entry point.
type CodexClientRestrictionDetector interface {
	Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult
}
// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。
// OpenAICodexClientRestrictionDetector is the default codex_cli_only detector
// for OpenAI OAuth accounts.
type OpenAICodexClientRestrictionDetector struct {
	cfg *config.Config // gateway config, consulted for the ForceCodexCLI override; may be nil
}
// NewOpenAICodexClientRestrictionDetector constructs the default detector
// backed by the given config (cfg may be nil).
func NewOpenAICodexClientRestrictionDetector(cfg *config.Config) *OpenAICodexClientRestrictionDetector {
	detector := &OpenAICodexClientRestrictionDetector{cfg: cfg}
	return detector
}
// Detect evaluates whether the request may use an account that restricts
// access to official Codex clients (codex_cli_only). The checks run in order:
// restriction disabled, ForceCodexCLI override, User-Agent whitelist,
// originator whitelist, and finally rejection.
func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
	verdict := func(enabled, matched bool, reason string) CodexClientRestrictionDetectionResult {
		return CodexClientRestrictionDetectionResult{Enabled: enabled, Matched: matched, Reason: reason}
	}
	// Restriction not enabled on this account: nothing to enforce.
	if account == nil || !account.IsCodexCLIOnlyEnabled() {
		return verdict(false, false, CodexClientRestrictionReasonDisabled)
	}
	// Operator-level escape hatch: ForceCodexCLI lets every request through.
	if d != nil && d.cfg != nil && d.cfg.Gateway.ForceCodexCLI {
		return verdict(true, true, CodexClientRestrictionReasonForceCodexCLI)
	}
	var userAgent, originator string
	if c != nil {
		userAgent = c.GetHeader("User-Agent")
		originator = c.GetHeader("originator")
	}
	switch {
	case openai.IsCodexOfficialClientRequest(userAgent):
		return verdict(true, true, CodexClientRestrictionReasonMatchedUA)
	case openai.IsCodexOfficialClientOriginator(originator):
		return verdict(true, true, CodexClientRestrictionReasonMatchedOriginator)
	default:
		return verdict(true, false, CodexClientRestrictionReasonNotMatchedUA)
	}
}
package service
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// newCodexDetectorTestContext builds a gin test context for POST /v1/responses
// carrying the given User-Agent and originator headers (empty values omitted).
func newCodexDetectorTestContext(ua string, originator string) *gin.Context {
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	for name, value := range map[string]string{"User-Agent": ua, "originator": originator} {
		if value != "" {
			req.Header.Set(name, value)
		}
	}
	ctx.Request = req
	return ctx
}
// TestOpenAICodexClientRestrictionDetector_Detect covers the codex_cli_only
// detection matrix: disabled accounts, official-client UA/originator matches,
// rejection of unknown clients, and the ForceCodexCLI override.
func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enabledExtra := func() map[string]any { return map[string]any{"codex_cli_only": true} }
	cases := []struct {
		name        string
		cfg         *config.Config
		extra       map[string]any
		ua          string
		originator  string
		wantEnabled bool
		wantMatched bool
		wantReason  string
	}{
		{name: "未开启开关时绕过", extra: map[string]any{}, ua: "curl/8.0", wantReason: CodexClientRestrictionReasonDisabled},
		{name: "开启后 codex_cli_rs 命中", extra: enabledExtra(), ua: "codex_cli_rs/0.99.0", wantEnabled: true, wantMatched: true, wantReason: CodexClientRestrictionReasonMatchedUA},
		{name: "开启后 codex_vscode 命中", extra: enabledExtra(), ua: "codex_vscode/1.0.0", wantEnabled: true, wantMatched: true, wantReason: CodexClientRestrictionReasonMatchedUA},
		{name: "开启后 codex_app 命中", extra: enabledExtra(), ua: "codex_app/2.1.0", wantEnabled: true, wantMatched: true, wantReason: CodexClientRestrictionReasonMatchedUA},
		{name: "开启后 originator 命中", extra: enabledExtra(), ua: "curl/8.0", originator: "codex_chatgpt_desktop", wantEnabled: true, wantMatched: true, wantReason: CodexClientRestrictionReasonMatchedOriginator},
		{name: "开启后非官方客户端拒绝", extra: enabledExtra(), ua: "curl/8.0", originator: "my_client", wantEnabled: true, wantReason: CodexClientRestrictionReasonNotMatchedUA},
		{
			name:        "开启 ForceCodexCLI 时允许通过",
			cfg:         &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}},
			extra:       enabledExtra(),
			ua:          "curl/8.0",
			originator:  "my_client",
			wantEnabled: true,
			wantMatched: true,
			wantReason:  CodexClientRestrictionReasonForceCodexCLI,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			detector := NewOpenAICodexClientRestrictionDetector(tc.cfg)
			account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: tc.extra}
			result := detector.Detect(newCodexDetectorTestContext(tc.ua, tc.originator), account)
			require.Equal(t, tc.wantEnabled, result.Enabled)
			require.Equal(t, tc.wantMatched, result.Matched)
			require.Equal(t, tc.wantReason, result.Reason)
		})
	}
}
......@@ -2,73 +2,66 @@ package service
import (
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
	// opencodeCodexHeaderURL is the upstream location of opencode's codex header prompt.
	opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
	// codexCacheTTL bounds how long a cached prompt is served before revalidating upstream.
	codexCacheTTL = 15 * time.Minute
)

// codexCLIInstructions is the bundled Codex CLI system prompt, embedded at build time.
//
//go:embed prompts/codex_cli_instructions.md
var codexCLIInstructions string
// codexModelMap normalizes requested model IDs — including reasoning-effort
// suffixed variants (-low/-medium/-high/-xhigh/-none) and legacy aliases —
// to canonical upstream Codex model names.
//
// Fix: the merge left both the old ("gpt-5.3" -> "gpt-5.3") and new
// ("gpt-5.3" -> "gpt-5.3-codex") entry sets in one literal, producing
// duplicate map keys (a Go compile error). Only the post-merge set is kept:
// the whole gpt-5.3* family deliberately maps to "gpt-5.3-codex".
var codexModelMap = map[string]string{
	"gpt-5.3":                    "gpt-5.3-codex",
	"gpt-5.3-none":               "gpt-5.3-codex",
	"gpt-5.3-low":                "gpt-5.3-codex",
	"gpt-5.3-medium":             "gpt-5.3-codex",
	"gpt-5.3-high":               "gpt-5.3-codex",
	"gpt-5.3-xhigh":              "gpt-5.3-codex",
	"gpt-5.3-codex":              "gpt-5.3-codex",
	"gpt-5.3-codex-spark":        "gpt-5.3-codex",
	"gpt-5.3-codex-spark-low":    "gpt-5.3-codex",
	"gpt-5.3-codex-spark-medium": "gpt-5.3-codex",
	"gpt-5.3-codex-spark-high":   "gpt-5.3-codex",
	"gpt-5.3-codex-spark-xhigh":  "gpt-5.3-codex",
	"gpt-5.3-codex-low":          "gpt-5.3-codex",
	"gpt-5.3-codex-medium":       "gpt-5.3-codex",
	"gpt-5.3-codex-high":         "gpt-5.3-codex",
	"gpt-5.3-codex-xhigh":        "gpt-5.3-codex",
	"gpt-5.1-codex":              "gpt-5.1-codex",
	"gpt-5.1-codex-low":          "gpt-5.1-codex",
	"gpt-5.1-codex-medium":       "gpt-5.1-codex",
	"gpt-5.1-codex-high":         "gpt-5.1-codex",
	"gpt-5.1-codex-max":          "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-low":      "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-medium":   "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-high":     "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-xhigh":    "gpt-5.1-codex-max",
	"gpt-5.2":                    "gpt-5.2",
	"gpt-5.2-none":               "gpt-5.2",
	"gpt-5.2-low":                "gpt-5.2",
	"gpt-5.2-medium":             "gpt-5.2",
	"gpt-5.2-high":               "gpt-5.2",
	"gpt-5.2-xhigh":              "gpt-5.2",
	"gpt-5.2-codex":              "gpt-5.2-codex",
	"gpt-5.2-codex-low":          "gpt-5.2-codex",
	"gpt-5.2-codex-medium":       "gpt-5.2-codex",
	"gpt-5.2-codex-high":         "gpt-5.2-codex",
	"gpt-5.2-codex-xhigh":        "gpt-5.2-codex",
	"gpt-5.1-codex-mini":         "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-medium":  "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-high":    "gpt-5.1-codex-mini",
	"gpt-5.1":                    "gpt-5.1",
	"gpt-5.1-none":               "gpt-5.1",
	"gpt-5.1-low":                "gpt-5.1",
	"gpt-5.1-medium":             "gpt-5.1",
	"gpt-5.1-high":               "gpt-5.1",
	"gpt-5.1-chat-latest":        "gpt-5.1",
	"gpt-5-codex":                "gpt-5.1-codex",
	"codex-mini-latest":          "gpt-5.1-codex-mini",
	"gpt-5-codex-mini":           "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-medium":    "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-high":      "gpt-5.1-codex-mini",
	"gpt-5":                      "gpt-5.1",
	"gpt-5-mini":                 "gpt-5.1",
	"gpt-5-nano":                 "gpt-5.1",
}
type codexTransformResult struct {
......@@ -77,12 +70,6 @@ type codexTransformResult struct {
PromptCacheKey string
}
// opencodeCacheMetadata is the JSON sidecar persisted next to a cached prompt
// file, driving ETag revalidation and TTL freshness checks.
type opencodeCacheMetadata struct {
	ETag        string `json:"etag"`                // ETag from the last fetch, replayed via If-None-Match
	LastFetch   string `json:"lastFetch,omitempty"` // RFC3339 timestamp of the last successful fetch
	LastChecked int64  `json:"lastChecked"`         // unix-millisecond timestamp used for the TTL check
}
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
......@@ -177,7 +164,7 @@ func normalizeCodexModel(model string) string {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
return "gpt-5.3"
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
return "gpt-5.1-codex-max"
......@@ -222,54 +209,9 @@ func getNormalizedCodexModel(modelID string) string {
return ""
}
// getOpenCodeCachedPrompt returns the prompt text served at url, backed by a
// local file cache (cacheFileName) plus a metadata sidecar (metaFileName) in
// the codex cache directory.
//
// Flow:
//  1. If cached content exists and the sidecar's LastChecked is within
//     codexCacheTTL, serve the cache without any network traffic.
//  2. Otherwise revalidate with a conditional GET (If-None-Match from the
//     stored ETag); a 304 re-uses the cached content.
//  3. A 2xx with a non-empty body refreshes both cache files.
//  4. On any failure, fall back to whatever cached content exists ("" if none).
//
// NOTE(review): a 304 does not refresh LastChecked, so once the TTL lapses
// every call revalidates upstream until a fresh 2xx arrives — confirm intended.
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
	cacheDir := codexCachePath("")
	if cacheDir == "" {
		// No resolvable home directory: caching (and fetching) is skipped.
		return ""
	}
	cacheFile := filepath.Join(cacheDir, cacheFileName)
	metaFile := filepath.Join(cacheDir, metaFileName)
	var cachedContent string
	if content, ok := readFile(cacheFile); ok {
		cachedContent = content
	}
	var meta opencodeCacheMetadata
	if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
		// Fresh enough: serve the cache without revalidating upstream.
		if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
			return cachedContent
		}
	}
	// meta.ETag is the zero value ("") when the sidecar was missing/unreadable.
	content, etag, status, err := fetchWithETag(url, meta.ETag)
	if err == nil && status == http.StatusNotModified && cachedContent != "" {
		// 304: upstream unchanged, keep serving the cached copy.
		return cachedContent
	}
	if err == nil && status >= 200 && status < 300 && content != "" {
		// Fresh content: persist body + metadata (writes are best-effort).
		_ = writeFile(cacheFile, content)
		meta = opencodeCacheMetadata{
			ETag:        etag,
			LastFetch:   time.Now().UTC().Format(time.RFC3339),
			LastChecked: time.Now().UnixMilli(),
		}
		_ = writeJSON(metaFile, meta)
		return content
	}
	// Fetch failed or returned nothing usable: fall back to stale cache if any.
	return cachedContent
}
// getOpenCodeCodexHeader returns the Codex header instructions. It first tries
// the cached/fetched copy of the opencode repository's codex_header.txt and
// falls back to the bundled Codex CLI instructions when that yields nothing.
func getOpenCodeCodexHeader() string {
	// Prefer the opencode repository prompt (served from the local cache,
	// revalidated against the upstream URL).
	opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
	// Use the opencode instructions when available.
	if opencodeInstructions != "" {
		return opencodeInstructions
	}
	// Otherwise fall back to the repository-bundled Codex CLI instructions,
	// which require no cache I/O or network access.
	return getCodexCLIInstructions()
}
......@@ -287,8 +229,8 @@ func GetCodexCLIInstructions() string {
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖
// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令)
// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
if isCodexCLI {
return applyCodexCLIInstructions(reqBody)
......@@ -297,13 +239,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令
// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
func applyCodexCLIInstructions(reqBody map[string]any) bool {
if !isInstructionsEmpty(reqBody) {
return false // 已有有效 instructions,不修改
}
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
instructions := strings.TrimSpace(getCodexCLIInstructions())
if instructions != "" {
reqBody["instructions"] = instructions
return true
......@@ -312,8 +254,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool {
return false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
// 优先使用 opencode 指令覆盖
// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
// 优先使用内置 Codex CLI 指令覆盖
func applyOpenCodeInstructions(reqBody map[string]any) bool {
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string)
......@@ -495,85 +437,3 @@ func normalizeCodexTools(reqBody map[string]any) bool {
return modified
}
func codexCachePath(filename string) string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
cacheDir := filepath.Join(home, ".opencode", "cache")
if filename == "" {
return cacheDir
}
return filepath.Join(cacheDir, filename)
}
func readFile(path string) (string, bool) {
if path == "" {
return "", false
}
data, err := os.ReadFile(path)
if err != nil {
return "", false
}
return string(data), true
}
func writeFile(path, content string) error {
if path == "" {
return fmt.Errorf("empty cache path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return os.WriteFile(path, []byte(content), 0o644)
}
func loadJSON(path string, target any) bool {
data, err := os.ReadFile(path)
if err != nil {
return false
}
if err := json.Unmarshal(data, target); err != nil {
return false
}
return true
}
func writeJSON(path string, value any) error {
if path == "" {
return fmt.Errorf("empty json path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.Marshal(value)
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}
func fetchWithETag(url, etag string) (string, string, int, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", "", 0, err
}
req.Header.Set("User-Agent", "sub2api-codex")
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", 0, err
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", resp.StatusCode, err
}
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}
package service
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.2",
......@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
// 显式 store=true 也会强制为 false。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
}
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"tools": []any{
......@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -178,97 +167,39 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// TestNormalizeCodexModel_Gpt53 pins the gpt-5.3 family normalization: every
// variant (including spark/effort suffixes and space-separated forms) resolves
// to "gpt-5.3-codex".
//
// Fix: the merge left both the old ("gpt-5.3" -> "gpt-5.3") and new
// ("gpt-5.3" -> "gpt-5.3-codex") expectations in one map literal, producing
// duplicate map keys (a Go compile error). Only the post-merge set is kept.
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
	cases := map[string]string{
		"gpt-5.3":                   "gpt-5.3-codex",
		"gpt-5.3-codex":             "gpt-5.3-codex",
		"gpt-5.3-codex-xhigh":       "gpt-5.3-codex",
		"gpt-5.3-codex-spark":       "gpt-5.3-codex",
		"gpt-5.3-codex-spark-high":  "gpt-5.3-codex",
		"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
		"gpt 5.3 codex":             "gpt-5.3-codex",
	}
	for input, expected := range cases {
		require.Equal(t, expected, normalizeCodexModel(input))
	}
}
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
// Codex CLI 场景:已有 instructions 时保持不变
setupCodexCache(t)
// Codex CLI 场景:已有 instructions 时不修改
reqBody := map[string]any{
"model": "gpt-5.1",
"instructions": "user custom instructions",
"input": []any{},
}
result := applyCodexOAuthTransform(reqBody, true)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.Equal(t, "user custom instructions", instructions)
// instructions 未变,但其他字段(如 store、stream)可能被修改
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充内置指令
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"input": []any{},
"instructions": "existing instructions",
}
result := applyCodexOAuthTransform(reqBody, true)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.NotEmpty(t, instructions)
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header)
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"input": []any{},
}
result := applyCodexOAuthTransform(reqBody, false)
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容
require.True(t, result.Modified)
}
func setupCodexCache(t *testing.T) {
t.Helper()
// 使用临时 HOME 避免触发网络拉取 header。
// Windows 使用 USERPROFILE,Unix 使用 HOME。
tempDir := t.TempDir()
t.Setenv("HOME", tempDir)
t.Setenv("USERPROFILE", tempDir)
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
meta := map[string]any{
"etag": "",
"lastFetch": time.Now().UTC().Format(time.RFC3339),
"lastChecked": time.Now().UnixMilli(),
}
data, err := json.Marshal(meta)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
require.Equal(t, "existing instructions", instructions)
// Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变
_ = result
}
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充默认值
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -284,8 +215,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
}
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用 opencode 指令覆盖
setupCodexCache(t)
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
reqBody := map[string]any{
"model": "gpt-5.1",
......
......@@ -10,9 +10,7 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
......@@ -20,10 +18,14 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
const (
......@@ -32,13 +34,15 @@ const (
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
)
codexCLIUserAgent = "codex_cli_rs/0.98.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
codexCLIOnlyHeaderValueMaxBytes = 256
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
)
// OpenAI allowed headers whitelist (for non-OAuth accounts)
// OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
......@@ -48,6 +52,35 @@ var openaiAllowedHeaders = map[string]bool{
"session_id": true,
}
// OpenAI passthrough allowed headers whitelist.
// In passthrough mode only these low-risk request headers are forwarded, so
// non-standard/environment-noise headers are not passed upstream where they
// could trigger risk control.
var openaiPassthroughAllowedHeaders = map[string]bool{
	"accept":          true,
	"accept-language": true,
	"content-type":    true,
	"conversation_id": true,
	"openai-beta":     true,
	"user-agent":      true,
	"originator":      true,
	"session_id":      true,
}
// codexCLIOnlyDebugHeaderWhitelist lists the request headers captured when a
// codex_cli_only rejection is logged (diagnostics only; never forwarded
// upstream).
var codexCLIOnlyDebugHeaderWhitelist = []string{
	"User-Agent",
	"Content-Type",
	"Accept",
	"Accept-Language",
	"OpenAI-Beta",
	"Originator",
	"Session_ID",
	"Conversation_ID",
	"X-Request-ID",
	"X-Client-Request-ID",
	"X-Forwarded-For",
	"X-Real-IP",
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
......@@ -175,6 +208,7 @@ type OpenAIGatewayService struct {
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
codexDetector CodexClientRestrictionDetector
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
......@@ -210,6 +244,7 @@ func NewOpenAIGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
......@@ -222,13 +257,228 @@ func NewOpenAIGatewayService(
}
}
// getCodexClientRestrictionDetector returns the service's configured detector,
// lazily constructing a default one when the service or its detector is nil.
func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector {
	if s == nil {
		return NewOpenAICodexClientRestrictionDetector(nil)
	}
	if s.codexDetector != nil {
		return s.codexDetector
	}
	return NewOpenAICodexClientRestrictionDetector(s.cfg)
}
// detectCodexClientRestriction runs codex_cli_only detection for the request
// against the given account.
func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
	detector := s.getCodexClientRestrictionDetector()
	return detector.Detect(c, account)
}
// getAPIKeyIDFromContext extracts the authenticated API key ID stashed in the
// gin context under "api_key", returning 0 when absent or of the wrong type.
func getAPIKeyIDFromContext(c *gin.Context) int64 {
	if c == nil {
		return 0
	}
	raw, found := c.Get("api_key")
	if !found {
		return 0
	}
	if key, valid := raw.(*APIKey); valid && key != nil {
		return key.ID
	}
	return 0
}
// logCodexCLIOnlyDetection emits a warning when a codex_cli_only-enabled
// account rejects a non-official client request. Nothing is logged when the
// restriction is disabled or when the request matched an official client.
//
// Fix: the original built the full zap field slice and contextual logger even
// when result.Matched was true and nothing would ever be logged; the matched
// case now returns before any of that work (emitted output is unchanged).
func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) {
	// Only enabled + unmatched detections produce a log line.
	if !result.Enabled || result.Matched {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	accountID := int64(0)
	if account != nil {
		accountID = account.ID
	}
	fields := []zap.Field{
		zap.String("component", "service.openai_gateway"),
		zap.Int64("account_id", accountID),
		zap.Bool("codex_cli_only_enabled", result.Enabled),
		zap.Bool("codex_official_client_match", result.Matched),
		zap.String("reject_reason", result.Reason),
	}
	if apiKeyID > 0 {
		fields = append(fields, zap.Int64("api_key_id", apiKeyID))
	}
	// Attach request diagnostics (method, path, whitelisted headers, body size).
	fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
	logger.FromContext(ctx).With(fields...).Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
}
// appendCodexCLIOnlyRejectedRequestFields enriches a zap field slice with
// request diagnostics for a rejected codex_cli_only request: method/path/host,
// client addressing, UA/content-type, parsed body metadata, a whitelisted
// header snapshot, and the body size. Returns fields unchanged when there is
// no request to describe.
func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field {
	if c == nil || c.Request == nil {
		return fields
	}
	req := c.Request
	requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
	fields = append(fields,
		zap.String("request_method", strings.TrimSpace(req.Method)),
		zap.String("request_path", strings.TrimSpace(req.URL.Path)),
		zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)),
		zap.String("request_host", strings.TrimSpace(req.Host)),
		zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())),
		zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)),
		zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))),
		zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))),
		zap.Int64("request_content_length", req.ContentLength),
		zap.Bool("request_stream", requestStream),
	)
	if requestModel != "" {
		fields = append(fields, zap.String("request_model", requestModel))
	}
	if promptCacheKey != "" {
		// Hash rather than log the raw key: it can correlate a user session.
		fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)))
	}
	// Only whitelisted headers are snapshotted, with values truncated.
	if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 {
		fields = append(fields, zap.Any("request_headers", headers))
	}
	fields = append(fields, zap.Int("request_body_size", len(body)))
	return fields
}
// snapshotCodexCLIOnlyHeaders copies the whitelisted diagnostic headers out of
// header, lower-casing keys and truncating long values. Returns nil when the
// header set is empty (callers additionally skip zero-length snapshots).
func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string {
	if len(header) == 0 {
		return nil
	}
	snapshot := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist))
	for _, name := range codexCLIOnlyDebugHeaderWhitelist {
		if value := strings.TrimSpace(header.Get(name)); value != "" {
			snapshot[strings.ToLower(name)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes)
		}
	}
	return snapshot
}
// hashSensitiveValueForLog returns a short (8-byte / 16-hex-char) SHA-256
// digest of the trimmed value, or "" for blank input, so sensitive values can
// be correlated across log lines without being disclosed.
func hashSensitiveValueForLog(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:8])
}
// logOpenAIInstructionsRequiredDebug logs full request diagnostics when the
// upstream response is the OpenAI "Instructions are required" 400 error, to
// help trace which client/request shape triggered it. No-op for any other
// upstream failure.
func logOpenAIInstructionsRequiredDebug(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	upstreamStatusCode int,
	upstreamMsg string,
	requestBody []byte,
	upstreamBody []byte,
) {
	msg := strings.TrimSpace(upstreamMsg)
	// Bail out unless the response matches the instructions-required pattern.
	if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	accountID := int64(0)
	accountName := ""
	if account != nil {
		accountID = account.ID
		accountName = strings.TrimSpace(account.Name)
	}
	userAgent := ""
	if c != nil {
		userAgent = strings.TrimSpace(c.GetHeader("User-Agent"))
	}
	fields := []zap.Field{
		zap.String("component", "service.openai_gateway"),
		zap.Int64("account_id", accountID),
		zap.String("account_name", accountName),
		zap.Int("upstream_status_code", upstreamStatusCode),
		zap.String("upstream_error_message", msg),
		zap.String("request_user_agent", userAgent),
		// Record whether the caller looked like the official Codex CLI.
		zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)),
	}
	fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody)
	logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查")
}
// isOpenAIInstructionsRequiredError reports whether an upstream response is
// OpenAI's 400 "Instructions are required" failure, matching either the
// pre-extracted message or the structured error object in the raw body.
func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool {
	if upstreamStatusCode != http.StatusBadRequest {
		return false
	}
	// mentionsRequiredInstructions matches the known phrasings of the error.
	mentionsRequiredInstructions := func(text string) bool {
		lower := strings.ToLower(strings.TrimSpace(text))
		switch {
		case lower == "":
			return false
		case strings.Contains(lower, "instructions are required"),
			strings.Contains(lower, "required parameter: 'instructions'"),
			strings.Contains(lower, "required parameter: instructions"):
			return true
		case strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions"):
			return true
		}
		// Loosest match last: any "instruction" + "required" combination.
		return strings.Contains(lower, "instruction") && strings.Contains(lower, "required")
	}
	if mentionsRequiredInstructions(upstreamMsg) {
		return true
	}
	if len(upstreamBody) == 0 {
		return false
	}
	// Inspect the structured error object for the same condition.
	errMsg := gjson.GetBytes(upstreamBody, "error.message").String()
	errMsgLower := strings.ToLower(strings.TrimSpace(errMsg))
	errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String()))
	errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String()))
	errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String()))
	switch {
	case errParam == "instructions":
		return true
	case mentionsRequiredInstructions(errMsg):
		return true
	case strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions"):
		return true
	case strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required"):
		return true
	}
	return false
}
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
// 1. Header: session_id
// 2. Header: conversation_id
// 3. Body: prompt_cache_key (opencode)
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
......@@ -237,10 +487,8 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[s
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && reqBody != nil {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
sessionID = strings.TrimSpace(v)
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" {
return ""
......@@ -744,30 +992,64 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now()
// Parse request body once (avoid multiple parse/serialize cycles)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
restrictionResult := s.detectCodexClientRestriction(c, account)
apiKeyID := getAPIKeyIDFromContext(c)
logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body)
if restrictionResult.Enabled && !restrictionResult.Matched {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "forbidden_error",
"message": "This account only allows Codex official clients",
},
})
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
}
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
passthroughEnabled := account.IsOpenAIPassthroughEnabled()
if passthroughEnabled {
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
}
// Extract model and stream from parsed body
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
promptCacheKey := ""
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
reqBody, err := getOpenAIRequestBodyMap(c, body)
if err != nil {
return nil, err
}
if v, ok := reqBody["model"].(string); ok {
reqModel = v
originalModel = reqModel
}
if v, ok := reqBody["stream"].(bool); ok {
reqStream = v
}
if promptCacheKey == "" {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
}
}
// Track if body needs re-serialization
bodyModified := false
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
// 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。
if !isCodexCLI && isInstructionsEmpty(reqBody) {
if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
bodyModified = true
}
}
// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
reqBody["model"] = mappedModel
bodyModified = true
}
......@@ -776,7 +1058,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if model, ok := reqBody["model"].(string); ok {
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" && normalizedModel != model {
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, normalizedModel, account.Name, account.Type, isCodexCLI)
reqBody["model"] = normalizedModel
mappedModel = normalizedModel
......@@ -789,7 +1071,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
reasoning["effort"] = "none"
bodyModified = true
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
}
}
......@@ -860,123 +1142,700 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, fmt.Errorf("serialize request body: %w", err)
}
}
// Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
// Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
if err != nil {
return nil, err
}
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// Capture upstream request body for ops retry of this attempt.
setOpsUpstreamRequestBody(c, body)
// Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// Handle error response
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
if err != nil {
return nil, err
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
// forwardOpenAIPassthrough forwards the request body to the upstream without
// the normal rewrite pipeline (no model mapping, no instruction injection).
//
// For OAuth accounts it first rejects requests that lack a usable
// instructions field (returning 403 to the client), then normalizes the body
// via normalizeOpenAIPassthroughOAuthBody; normalization forces streaming.
// On upstream transport failure it writes a 502 to the client itself; on
// upstream HTTP errors (>= 400) it relays the upstream response verbatim via
// handleErrorResponsePassthrough — passthrough mode never fails over.
//
// Returns the forward result (usage, latency, model echo) or an error; when
// an error is returned a response has already been written to c.
func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	reqModel string,
	reasoningEffort *string,
	reqStream bool,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	if account != nil && account.Type == AccountTypeOAuth {
		// Local pre-flight: reject Codex OAuth requests with missing/empty
		// instructions before spending an upstream call.
		if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
			rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
			setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "")
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: http.StatusForbidden,
				Passthrough:        true,
				Kind:               "request_error",
				Message:            rejectMsg,
				Detail:             rejectReason,
			})
			logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body)
			c.JSON(http.StatusForbidden, gin.H{
				"error": gin.H{
					"type":    "forbidden_error",
					"message": rejectMsg,
				},
			})
			return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason)
		}
		normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
		if err != nil {
			return nil, err
		}
		if normalized {
			// Normalization implies the ChatGPT internal API, which is
			// stream-only; force streaming regardless of the client's flag.
			body = normalizedBody
			reqStream = true
		}
	}
	logger.LegacyPrintf("service.openai_gateway",
		"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
		account.ID,
		account.Name,
		account.Type,
		reqModel,
		reqStream,
	)
	// Warn when timeout-related request headers are present on a streaming
	// request: depending on config they are either forwarded (risking early
	// upstream disconnect) or filtered.
	if reqStream && c != nil && c.Request != nil {
		if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 {
			streamWarnLogger := logger.FromContext(ctx).With(
				zap.String("component", "service.openai_gateway"),
				zap.Int64("account_id", account.ID),
				zap.Strings("timeout_headers", timeoutHeaders),
			)
			if s.isOpenAIPassthroughTimeoutHeadersAllowed() {
				streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流")
			} else {
				streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险")
			}
		}
	}
	// Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
	if err != nil {
		return nil, err
	}
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	// Capture the upstream request body so ops can retry this attempt.
	setOpsUpstreamRequestBody(c, body)
	if c != nil {
		c.Set("openai_passthrough", true)
	}
	upstreamStart := time.Now()
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
	if err != nil {
		// Transport-level failure: record it and answer the client with 502
		// (handlers assume Forward writes on non-failover errors).
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Passthrough:        true,
			Kind:               "request_error",
			Message:            safeErr,
		})
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream request failed",
			},
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		// Passthrough mode deliberately does not fail over (that would change
		// the original upstream semantics); relay the upstream error as-is.
		return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body)
	}
	var usage *OpenAIUsage
	var firstTokenMs *int
	if reqStream {
		result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime)
		if err != nil {
			return nil, err
		}
		usage = result.usage
		firstTokenMs = result.firstTokenMs
	} else {
		usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c)
		if err != nil {
			return nil, err
		}
	}
	// Persist Codex rate-limit snapshot from response headers when present.
	if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
		s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
	}
	if usage == nil {
		usage = &OpenAIUsage{}
	}
	return &OpenAIForwardResult{
		RequestID:       resp.Header.Get("x-request-id"),
		Usage:           *usage,
		Model:           reqModel,
		ReasoningEffort: reasoningEffort,
		Stream:          reqStream,
		Duration:        time.Since(startTime),
		FirstTokenMs:    firstTokenMs,
	}, nil
}
// logOpenAIPassthroughInstructionsRejected emits a structured warning when a
// passthrough Codex request is rejected locally for lacking a usable
// instructions field. All account fields tolerate a nil account.
func logOpenAIPassthroughInstructionsRejected(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	reqModel string,
	rejectReason string,
	body []byte,
) {
	if ctx == nil {
		ctx = context.Background()
	}
	var (
		id       int64
		name     string
		typeName string
	)
	if account != nil {
		id = account.ID
		name = strings.TrimSpace(account.Name)
		typeName = strings.TrimSpace(string(account.Type))
	}
	fields := appendCodexCLIOnlyRejectedRequestFields([]zap.Field{
		zap.String("component", "service.openai_gateway"),
		zap.Int64("account_id", id),
		zap.String("account_name", name),
		zap.String("account_type", typeName),
		zap.String("request_model", strings.TrimSpace(reqModel)),
		zap.String("reject_reason", strings.TrimSpace(rejectReason)),
	}, c, body)
	logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")
}
// buildUpstreamRequestOpenAIPassthrough constructs the upstream HTTP request
// for passthrough mode: target URL chosen by account type, client headers
// copied through a safe whitelist, auth injected, and (for OAuth accounts)
// the extra headers the ChatGPT internal Codex API expects.
//
// User-Agent precedence (later wins): whitelisted client UA -> account custom
// UA -> ForceCodexCLI config -> OAuth safety fallback to the Codex CLI UA
// when the resulting UA is still not a Codex one.
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	token string,
) (*http.Request, error) {
	// Pick the upstream endpoint: OAuth goes to the ChatGPT Codex URL,
	// API-key accounts may point at a validated custom base URL.
	targetURL := openaiPlatformAPIURL
	switch account.Type {
	case AccountTypeOAuth:
		targetURL = chatgptCodexURL
	case AccountTypeAPIKey:
		baseURL := account.GetOpenAIBaseURL()
		if baseURL != "" {
			validatedURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, err
			}
			targetURL = buildOpenAIResponsesURL(validatedURL)
		}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	// Copy client request headers through the security whitelist; timeout
	// headers are only forwarded when explicitly allowed by config.
	allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
	if c != nil && c.Request != nil {
		for key, values := range c.Request.Header {
			lower := strings.ToLower(strings.TrimSpace(key))
			if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) {
				continue
			}
			for _, v := range values {
				req.Header.Add(key, v)
			}
		}
	}
	// Strip any inbound credential residue, then inject upstream auth.
	req.Header.Del("authorization")
	req.Header.Del("x-api-key")
	req.Header.Del("x-goog-api-key")
	req.Header.Set("authorization", "Bearer "+token)
	// OAuth passthrough to the ChatGPT internal API needs extra headers;
	// client-supplied values win, defaults only fill gaps.
	if account.Type == AccountTypeOAuth {
		promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
		req.Host = "chatgpt.com"
		if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
			req.Header.Set("chatgpt-account-id", chatgptAccountID)
		}
		if req.Header.Get("accept") == "" {
			req.Header.Set("accept", "text/event-stream")
		}
		if req.Header.Get("OpenAI-Beta") == "" {
			req.Header.Set("OpenAI-Beta", "responses=experimental")
		}
		if req.Header.Get("originator") == "" {
			req.Header.Set("originator", "codex_cli_rs")
		}
		// Derive sticky-session headers from prompt_cache_key when absent.
		if promptCacheKey != "" {
			if req.Header.Get("conversation_id") == "" {
				req.Header.Set("conversation_id", promptCacheKey)
			}
			if req.Header.Get("session_id") == "" {
				req.Header.Set("session_id", promptCacheKey)
			}
		}
	}
	// Passthrough also honors the account's custom User-Agent and the
	// ForceCodexCLI override (config wins over the account setting).
	customUA := account.GetOpenAIUserAgent()
	if customUA != "" {
		req.Header.Set("user-agent", customUA)
	}
	if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
		req.Header.Set("user-agent", codexCLIUserAgent)
	}
	// OAuth safety net: a non-Codex UA is replaced to reduce the chance of
	// upstream risk-control rejection.
	if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) {
		req.Header.Set("user-agent", codexCLIUserAgent)
	}
	if req.Header.Get("content-type") == "" {
		req.Header.Set("content-type", "application/json")
	}
	return req, nil
}
// handleErrorResponsePassthrough records an upstream HTTP error (>= 400) for
// ops, then relays the upstream status, filtered headers, and body to the
// client unchanged. The returned error carries the upstream status (and the
// sanitized upstream message when one could be extracted).
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
	requestBody []byte,
) error {
	// Read at most 2 MiB of the error body.
	respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
	// Optionally keep a truncated copy of the raw body for ops logging.
	detail := ""
	if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		limit := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if limit <= 0 {
			limit = 2048
		}
		detail = truncateString(string(respBody), limit)
	}
	setOpsUpstreamError(c, resp.StatusCode, msg, detail)
	logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, msg, requestBody, respBody)
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Platform:             account.Platform,
		AccountID:            account.ID,
		AccountName:          account.Name,
		UpstreamStatusCode:   resp.StatusCode,
		UpstreamRequestID:    resp.Header.Get("x-request-id"),
		Passthrough:          true,
		Kind:                 "http_error",
		Message:              msg,
		Detail:               detail,
		UpstreamResponseBody: detail,
	})
	// Mirror the upstream response to the client as-is.
	writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
	ct := resp.Header.Get("Content-Type")
	if ct == "" {
		ct = "application/json"
	}
	c.Data(resp.StatusCode, ct, respBody)
	if msg == "" {
		return fmt.Errorf("upstream error: %d", resp.StatusCode)
	}
	return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, msg)
}
// isOpenAIPassthroughAllowedRequestHeader reports whether a client request
// header (already lowercased) may be forwarded upstream in passthrough mode.
// Timeout-related headers are gated by allowTimeoutHeaders; everything else
// must appear in the static whitelist.
func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool {
	switch {
	case lowerKey == "":
		return false
	case isOpenAIPassthroughTimeoutHeader(lowerKey):
		return allowTimeoutHeaders
	default:
		return openaiPassthroughAllowedHeaders[lowerKey]
	}
}
// isOpenAIPassthroughTimeoutHeader reports whether the given lowercase header
// name is one of the client/SDK timeout hints that can make the upstream cut
// a long-running stream short. Matching is exact (callers lowercase first).
func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool {
	return lowerKey == "x-stainless-timeout" ||
		lowerKey == "x-stainless-read-timeout" ||
		lowerKey == "x-stainless-connect-timeout" ||
		lowerKey == "x-request-timeout" ||
		lowerKey == "request-timeout" ||
		lowerKey == "grpc-timeout"
}
// isOpenAIPassthroughTimeoutHeadersAllowed reports whether configuration
// permits forwarding timeout-related headers upstream. Safe on a nil receiver
// or nil config (both yield false).
func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool {
	if s == nil || s.cfg == nil {
		return false
	}
	return s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders
}
// collectOpenAIPassthroughTimeoutHeaders returns a sorted list of the
// timeout-related headers present in h, each formatted as "name" or
// "name=v1|v2" when values exist. Returns nil for a nil header or when no
// timeout header is present.
func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
	if h == nil {
		return nil
	}
	var found []string
	for name, vals := range h {
		key := strings.ToLower(strings.TrimSpace(name))
		if !isOpenAIPassthroughTimeoutHeader(key) {
			continue
		}
		if len(vals) == 0 {
			found = append(found, key)
			continue
		}
		found = append(found, fmt.Sprintf("%s=%s", key, strings.Join(vals, "|")))
	}
	sort.Strings(found)
	return found
}
// openaiStreamingResultPassthrough carries the aggregate outcome of relaying
// a passthrough SSE stream: token usage parsed from the data lines and the
// time-to-first-token latency.
type openaiStreamingResultPassthrough struct {
	usage        *OpenAIUsage // usage accumulated from SSE data events
	firstTokenMs *int         // ms from request start to first non-[DONE] data line; nil if none seen
}
// handleStreamingResponsePassthrough relays an upstream SSE stream to the
// client line-for-line while extracting usage and first-token latency for
// billing/metrics.
//
// A client disconnect does not abort the upstream read: the stream keeps
// draining so the collected usage stays accurate. Context cancellation and
// post-disconnect read errors therefore return the partial result with a nil
// error; only genuine upstream read failures surface as errors.
//
// Fix: removed two stray merge-artifact lines that rebuilt an upstream
// request here ("upstreamReq, err := s.buildUpstreamRequest(...)") — they
// referenced variables (token, reqStream, promptCacheKey, isCodexCLI) that
// are not in scope in this function and left upstreamReq unused, so the file
// could not compile.
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
	startTime time.Time,
) (*openaiStreamingResultPassthrough, error) {
	writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
	// SSE headers
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")
	if v := resp.Header.Get("x-request-id"); v != "" {
		c.Header("x-request-id", v)
	}
	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}
	usage := &OpenAIUsage{}
	var firstTokenMs *int
	clientDisconnected := false
	sawDone := false
	upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	// Reuse a pooled 64K scanner buffer to avoid a per-stream allocation.
	scanBuf := getSSEScannerBuf64K()
	scanner.Buffer(scanBuf[:0], maxLineSize)
	defer putSSEScannerBuf64K(scanBuf)
	for scanner.Scan() {
		line := scanner.Text()
		if data, ok := extractOpenAISSEDataLine(line); ok {
			trimmedData := strings.TrimSpace(data)
			if trimmedData == "[DONE]" {
				sawDone = true
			}
			// First non-empty, non-terminal data line marks time-to-first-token.
			if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
				ms := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &ms
			}
			s.parseSSEUsage(data, usage)
		}
		if !clientDisconnected {
			if _, err := fmt.Fprintln(w, line); err != nil {
				// Client is gone; keep reading upstream so usage is complete.
				clientDisconnected = true
				logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
			} else {
				flusher.Flush()
			}
		}
	}
	if err := scanner.Err(); err != nil {
		if clientDisconnected {
			logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
			return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
		}
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			logger.LegacyPrintf("service.openai_gateway",
				"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
				account.ID,
				upstreamRequestID,
				err,
				ctx.Err(),
			)
			return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
		}
		if errors.Is(err, bufio.ErrTooLong) {
			logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
			return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
		}
		logger.LegacyPrintf("service.openai_gateway",
			"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
			account.ID,
			upstreamRequestID,
			err,
		)
		return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
	}
	// Stream ended without [DONE] while both sides were still live: likely a
	// truncated upstream stream; log for observability but do not fail.
	if !clientDisconnected && !sawDone && ctx.Err() == nil {
		logger.FromContext(ctx).With(
			zap.String("component", "service.openai_gateway"),
			zap.Int64("account_id", account.ID),
			zap.String("upstream_request_id", upstreamRequestID),
		).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
	}
	return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
c *gin.Context,
) (*OpenAIUsage, error) {
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err
}
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
usage := &OpenAIUsage{}
usageParsed := false
if len(body) > 0 {
var response struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
} `json:"input_tokens_details"`
} `json:"usage"`
}
if json.Unmarshal(body, &response) == nil {
usage.InputTokens = response.Usage.InputTokens
usage.OutputTokens = response.Usage.OutputTokens
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
usageParsed = true
}
}
// Send request
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
if !usageParsed {
// 兜底:尝试从 SSE 文本中解析 usage
usage = s.parseSSEUsageFromBody(string(body))
}
defer func() { _ = resp.Body.Close() }()
// Handle error response
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) {
if dst == nil || src == nil {
return
}
if cfg != nil {
responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders)
} else {
// 兜底:尽量保留最基础的 content-type
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
dst.Set("Content-Type", v)
}
return s.handleErrorResponse(ctx, resp, c, account)
}
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
// 透传模式强制放行 x-codex-* 响应头(若上游返回)。
// 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应,
// 这里用 EqualFold 做一次大小写不敏感的查找。
getCaseInsensitiveValues := func(h http.Header, want string) []string {
if h == nil {
return nil
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
if err != nil {
return nil, err
for k, vals := range h {
if strings.EqualFold(k, want) {
return vals
}
}
return nil
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
for _, rawKey := range []string{
"x-codex-primary-used-percent",
"x-codex-primary-reset-after-seconds",
"x-codex-primary-window-minutes",
"x-codex-secondary-used-percent",
"x-codex-secondary-reset-after-seconds",
"x-codex-secondary-window-minutes",
"x-codex-primary-over-secondary-limit-percent",
} {
vals := getCaseInsensitiveValues(src, rawKey)
if len(vals) == 0 {
continue
}
key := http.CanonicalHeaderKey(rawKey)
dst.Del(key)
for _, v := range vals {
dst.Add(key, v)
}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
......@@ -996,7 +1855,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if err != nil {
return nil, err
}
targetURL = validatedURL + "/responses"
targetURL = buildOpenAIResponsesURL(validatedURL)
}
default:
targetURL = openaiPlatformAPIURL
......@@ -1050,6 +1909,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("user-agent", customUA)
}
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", codexCLIUserAgent)
}
// Ensure required headers exist
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
......@@ -1058,7 +1923,13 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
return req, nil
}
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
func (s *OpenAIGatewayService) handleErrorResponse(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
requestBody []byte,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
......@@ -1072,9 +1943,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
logger.LegacyPrintf("service.openai_gateway",
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
......@@ -1230,7 +2102,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct {
line string
......@@ -1249,7 +2122,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -1260,7 +2134,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
......@@ -1332,16 +2206,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
......@@ -1353,8 +2227,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
lastDataAt = time.Now()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data, ok := extractOpenAISSEDataLine(line); ok {
// Replace model in response if needed
if needModelReplace {
......@@ -1371,7 +2244,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
......@@ -1388,7 +2261,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
......@@ -1401,10 +2274,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
if clientDisconnected {
log.Printf("Upstream timeout after client disconnect, returning collected usage")
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
......@@ -1421,7 +2294,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
continue
}
flusher.Flush()
......@@ -1430,40 +2303,47 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
// extractOpenAISSEDataLine extracts the payload of an SSE `data:` line with
// minimal overhead. It accepts both the `data: xxx` and `data:xxx` formats.
// It returns the payload and true for a data line, or ("", false) otherwise.
func extractOpenAISSEDataLine(line string) (string, bool) {
	if !strings.HasPrefix(line, "data:") {
		return "", false
	}
	start := len("data:")
	// Skip optional whitespace between the colon and the payload. The
	// original condition tested ' ' twice; one branch was evidently meant
	// to be '\t', so check space and tab explicitly.
	for start < len(line) {
		if line[start] != ' ' && line[start] != '\t' {
			break
		}
		start++
	}
	return line[start:], true
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
return line
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
// Replace model in response
if m, ok := event["model"].(string); ok && m == fromModel {
event["model"] = toModel
newData, err := json.Marshal(event)
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
newData, err := sjson.Set(data, "model", toModel)
if err != nil {
return line
}
return "data: " + string(newData)
return "data: " + newData
}
// Check nested response
if response, ok := event["response"].(map[string]any); ok {
if m, ok := response["model"].(string); ok && m == fromModel {
response["model"] = toModel
newData, err := json.Marshal(event)
if err != nil {
return line
}
return "data: " + string(newData)
// 检查嵌套的 response.model 字段
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
newData, err := sjson.Set(data, "response.model", toModel)
if err != nil {
return line
}
return "data: " + newData
}
return line
......@@ -1484,30 +2364,35 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
Type string `json:"type"`
Response struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
} `json:"input_tokens_details"`
} `json:"usage"`
} `json:"response"`
if usage == nil || data == "" || data == "[DONE]" {
return
}
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
usage.InputTokens = event.Response.Usage.InputTokens
usage.OutputTokens = event.Response.Usage.OutputTokens
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
if !strings.Contains(data, `"response.completed"`) {
return
}
if gjson.Get(data, "type").String() != "response.completed" {
return
}
usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err
}
......@@ -1613,10 +2498,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
......@@ -1640,10 +2525,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
......@@ -1655,7 +2540,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
lines := strings.Split(body, "\n")
for i, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
if _, ok := extractOpenAISSEDataLine(line); !ok {
continue
}
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
......@@ -1682,24 +2567,31 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
return normalized, nil
}
// buildOpenAIResponsesURL assembles the OpenAI Responses endpoint from a base
// URL (leading/trailing whitespace and trailing slashes are stripped first):
//   - base already ending in /responses: returned as-is
//   - base ending in /v1: append /responses
//   - anything else: append /v1/responses
//
// NOTE: this range previously carried stale removed lines of the old
// replaceModelInResponseBody implementation (merge residue); the new
// implementation lives below as its own function.
func buildOpenAIResponsesURL(base string) string {
	normalized := strings.TrimRight(strings.TrimSpace(base), "/")
	if strings.HasSuffix(normalized, "/responses") {
		return normalized
	}
	if strings.HasSuffix(normalized, "/v1") {
		return normalized + "/responses"
	}
	return normalized + "/v1/responses"
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil {
return body
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
newBody, err := sjson.SetBytes(body, "model", toModel)
if err != nil {
return body
}
return newBody
}
return newBody
return body
}
// OpenAIRecordUsageInput input for recording usage
......@@ -1803,7 +2695,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
......@@ -1826,7 +2718,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Update API key quota if applicable (only for balance mode with quota set)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
log.Printf("Update API key quota failed: %v", err)
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
}
}
......@@ -1904,16 +2796,41 @@ func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
return snapshot
}
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time {
if snapshot == nil {
return
return fallback
}
if snapshot.UpdatedAt == "" {
return fallback
}
base, err := time.Parse(time.RFC3339, snapshot.UpdatedAt)
if err != nil {
return fallback
}
return base
}
func codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string {
if resetAfterSeconds == nil {
return nil
}
sec := *resetAfterSeconds
if sec < 0 {
sec = 0
}
resetAt := base.Add(time.Duration(sec) * time.Second).Format(time.RFC3339)
return &resetAt
}
func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any {
if snapshot == nil {
return nil
}
// Convert snapshot to map for merging into Extra
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
updates := make(map[string]any)
// Save raw primary/secondary fields for debugging/tracing
// 保存原始 primary/secondary 字段,便于排查问题
if snapshot.PrimaryUsedPercent != nil {
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
}
......@@ -1935,9 +2852,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if snapshot.PrimaryOverSecondaryPercent != nil {
updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent
}
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339)
// Normalize to canonical 5h/7d fields
// 归一化到 5h/7d 规范字段
if normalized := snapshot.Normalize(); normalized != nil {
if normalized.Used5hPercent != nil {
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
......@@ -1957,6 +2874,29 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if normalized.Window7dMinutes != nil {
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
}
if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil {
updates["codex_5h_reset_at"] = *reset5hAt
}
if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); reset7dAt != nil {
updates["codex_7d_reset_at"] = *reset7dAt
}
}
return updates
}
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
if snapshot == nil {
return
}
if s == nil || s.accountRepo == nil {
return
}
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
if len(updates) == 0 {
return
}
// Update account's Extra field asynchronously
......@@ -2013,6 +2953,103 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
}
// extractOpenAIRequestMetaFromBody pulls the model name, stream flag and
// prompt_cache_key out of a raw OpenAI request body using targeted gjson
// lookups instead of a full JSON decode. String fields are whitespace-trimmed;
// an empty body yields zero values.
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
	if len(body) == 0 {
		return "", false, ""
	}
	model = strings.TrimSpace(gjson.GetBytes(body, "model").String())
	promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
	stream = gjson.GetBytes(body, "stream").Bool()
	return model, stream, promptCacheKey
}
// normalizeOpenAIPassthroughOAuthBody aligns a passthrough OAuth request body
// with the legacy pipeline's key behaviors: 1) store=false 2) stream=true.
// It returns the (possibly rewritten) body, whether it was modified, and any
// rewrite error. An empty body is passed through untouched.
func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
	if len(body) == 0 {
		return body, false, nil
	}
	out := body
	modified := false
	// Force store=false unless it is already an explicit JSON false.
	if v := gjson.GetBytes(out, "store"); !v.Exists() || v.Type != gjson.False {
		updated, err := sjson.SetBytes(out, "store", false)
		if err != nil {
			return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err)
		}
		out = updated
		modified = true
	}
	// Force stream=true unless it is already an explicit JSON true.
	if v := gjson.GetBytes(out, "stream"); !v.Exists() || v.Type != gjson.True {
		updated, err := sjson.SetBytes(out, "stream", true)
		if err != nil {
			return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err)
		}
		out = updated
		modified = true
	}
	return out, modified, nil
}
// detectOpenAIPassthroughInstructionsRejectReason validates the `instructions`
// field for codex models. It returns a machine-readable reject reason, or ""
// when the request is acceptable (including all non-codex models).
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
	normalizedModel := strings.ToLower(strings.TrimSpace(reqModel))
	if !strings.Contains(normalizedModel, "codex") {
		// Only codex models require instructions.
		return ""
	}
	field := gjson.GetBytes(body, "instructions")
	switch {
	case !field.Exists():
		return "instructions_missing"
	case field.Type != gjson.String:
		return "instructions_not_string"
	case strings.TrimSpace(field.String()) == "":
		return "instructions_empty"
	default:
		return ""
	}
}
// extractOpenAIReasoningEffortFromBody resolves the reasoning effort for a
// request. Precedence: `reasoning.effort`, then the flat `reasoning_effort`
// field, then a value derived from the model-name suffix. It returns nil when
// no value is found or the found value normalizes to empty.
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
	effort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
	if effort == "" {
		effort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
	}
	if effort == "" {
		// Nothing explicit in the body: fall back to the model suffix.
		derived := deriveOpenAIReasoningEffortFromModel(requestedModel)
		if derived == "" {
			return nil
		}
		return &derived
	}
	normalized := normalizeOpenAIReasoningEffort(effort)
	if normalized == "" {
		return nil
	}
	return &normalized
}
// getOpenAIRequestBodyMap returns the request body as a generic map. It
// prefers a body already parsed by earlier middleware (cached on the gin
// context under OpenAIParsedRequestBodyKey) and only falls back to decoding
// the raw bytes when no usable cache entry exists.
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
	if c != nil {
		if cached, exists := c.Get(OpenAIParsedRequestBodyKey); exists {
			if parsed, valid := cached.(map[string]any); valid && parsed != nil {
				return parsed, nil
			}
		}
	}
	var parsed map[string]any
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	return parsed, nil
}
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if value == "" {
......
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// stubCodexRestrictionDetector is a test double that always returns a
// preconfigured detection result, regardless of request or account.
type stubCodexRestrictionDetector struct {
	result CodexClientRestrictionDetectionResult
}

// Detect ignores both arguments and returns the canned result.
func (d *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account) CodexClientRestrictionDetectionResult {
	return d.result
}
// TestOpenAIGatewayService_GetCodexClientRestrictionDetector covers the three
// resolution paths of getCodexClientRestrictionDetector: an explicitly
// injected detector, a nil service receiver, and a service without an injected
// detector (which must fall back to a usable default).
func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) {
	gin.SetMode(gin.TestMode)
	t.Run("使用注入的 detector", func(t *testing.T) {
		// An injected detector must be returned as-is (same pointer).
		expected := &stubCodexRestrictionDetector{
			result: CodexClientRestrictionDetectionResult{Enabled: true, Matched: true, Reason: "stub"},
		}
		svc := &OpenAIGatewayService{codexDetector: expected}
		got := svc.getCodexClientRestrictionDetector()
		require.Same(t, expected, got)
	})
	t.Run("service 为 nil 时返回默认 detector", func(t *testing.T) {
		// A nil receiver must still yield a non-nil default detector.
		var svc *OpenAIGatewayService
		got := svc.getCodexClientRestrictionDetector()
		require.NotNil(t, got)
	})
	t.Run("service 未注入 detector 时返回默认 detector", func(t *testing.T) {
		// With ForceCodexCLI enabled, the default detector matches (allows)
		// the request even though "curl/8.0" is not an official client UA.
		svc := &OpenAIGatewayService{cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}}}
		got := svc.getCodexClientRestrictionDetector()
		require.NotNil(t, got)
		rec := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(rec)
		c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
		c.Request.Header.Set("User-Agent", "curl/8.0")
		account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}}
		result := got.Detect(c, account)
		require.True(t, result.Enabled)
		require.True(t, result.Matched)
		require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
	})
}
// TestGetAPIKeyIDFromContext verifies getAPIKeyIDFromContext degrades to 0 for
// a nil context, a missing key, a wrong value type and a typed-nil pointer,
// and reads the ID from a well-formed *APIKey.
func TestGetAPIKeyIDFromContext(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cases := []struct {
		name  string
		setup func() *gin.Context
		want  int64
	}{
		{
			name:  "context 为 nil",
			setup: func() *gin.Context { return nil },
			want:  0,
		},
		{
			name: "上下文没有 api_key",
			setup: func() *gin.Context {
				c, _ := gin.CreateTestContext(httptest.NewRecorder())
				return c
			},
			want: 0,
		},
		{
			name: "api_key 类型错误",
			setup: func() *gin.Context {
				c, _ := gin.CreateTestContext(httptest.NewRecorder())
				c.Set("api_key", "not-api-key")
				return c
			},
			want: 0,
		},
		{
			name: "api_key 指针为空",
			setup: func() *gin.Context {
				c, _ := gin.CreateTestContext(httptest.NewRecorder())
				var k *APIKey
				c.Set("api_key", k)
				return c
			},
			want: 0,
		},
		{
			name: "正常读取 api_key_id",
			setup: func() *gin.Context {
				c, _ := gin.CreateTestContext(httptest.NewRecorder())
				c.Set("api_key", &APIKey{ID: 12345})
				return c
			},
			want: 12345,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, getAPIKeyIDFromContext(tc.setup()))
		})
	}
}
// TestLogCodexCLIOnlyDetection_NilSafety guards against panics when the
// context, account and body arguments are nil.
func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) {
	// Do not verify log content; only ensure nil inputs never panic.
	require.NotPanics(t, func() {
		logCodexCLIOnlyDetection(context.TODO(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: true, Matched: false, Reason: "test"}, nil)
		logCodexCLIOnlyDetection(context.Background(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: false, Matched: false, Reason: "disabled"}, nil)
	})
}
// TestLogCodexCLIOnlyDetection_OnlyLogsRejected checks that a matched
// (allowed) detection produces no log line, while a rejected one does.
func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) {
	logSink, restore := captureStructuredLog(t)
	defer restore()
	account := &Account{ID: 1001}
	// Matched result: must not be logged.
	logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{
		Enabled: true,
		Matched: true,
		Reason:  CodexClientRestrictionReasonMatchedUA,
	}, nil)
	// Rejected result: must be logged.
	logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{
		Enabled: true,
		Matched: false,
		Reason:  CodexClientRestrictionReasonNotMatchedUA,
	}, nil)
	require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
	require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求"))
}
// TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails verifies the
// rejection log carries request metadata (UA, model, query, header names),
// records the prompt_cache_key only as a SHA-256 hash, and never includes a
// raw request-body preview.
func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logSink, restore := captureStructuredLog(t)
	defer restore()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
	c.Request.Header.Set("Content-Type", "application/json")
	c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
	// The body deliberately carries a sensitive access_token, which must not
	// surface in the log output.
	body := []byte(`{"model":"gpt-5.2","stream":false,"prompt_cache_key":"pc-123","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`)
	account := &Account{ID: 1001}
	logCodexCLIOnlyDetection(context.Background(), c, account, 2002, CodexClientRestrictionDetectionResult{
		Enabled: true,
		Matched: false,
		Reason:  CodexClientRestrictionReasonNotMatchedUA,
	}, body)
	require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
	require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
	require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
	// The prompt cache key must appear only as its hash.
	require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))
	require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta"))
	require.True(t, logSink.ContainsField("request_body_size"))
	require.False(t, logSink.ContainsField("request_body_preview"))
}
// TestLogOpenAIInstructionsRequiredDebug_LogsRequestDetails verifies that the
// "Instructions are required" debug helper emits a warn-level log with request
// metadata (UA, model, query, account name, header names) but no raw body.
func TestLogOpenAIInstructionsRequiredDebug_LogsRequestDetails(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logSink, restore := captureStructuredLog(t)
	defer restore()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "curl/8.0")
	c.Request.Header.Set("Content-Type", "application/json")
	c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
	// Body carries a sensitive access_token; only the body size may be logged.
	body := []byte(`{"model":"gpt-5.1-codex","stream":false,"prompt_cache_key":"pc-abc","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`)
	account := &Account{ID: 1001, Name: "codex max套餐"}
	logOpenAIInstructionsRequiredDebug(
		context.Background(),
		c,
		account,
		http.StatusBadRequest,
		"Instructions are required",
		body,
		[]byte(`{"error":{"message":"Instructions are required","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`),
	)
	require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn"))
	require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0"))
	require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex"))
	require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
	require.True(t, logSink.ContainsFieldValue("account_name", "codex max套餐"))
	require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta"))
	require.True(t, logSink.ContainsField("request_body_size"))
	require.False(t, logSink.ContainsField("request_body_preview"))
}
// TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped ensures the
// debug helper stays silent for upstream errors other than the
// "Instructions are required" case it targets (here: a 403 "forbidden").
func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logSink, restore := captureStructuredLog(t)
	defer restore()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "curl/8.0")
	body := []byte(`{"model":"gpt-5.1-codex","stream":false}`)
	logOpenAIInstructionsRequiredDebug(
		context.Background(),
		c,
		&Account{ID: 1001},
		http.StatusForbidden,
		"forbidden",
		body,
		[]byte(`{"error":{"message":"forbidden"}}`),
	)
	require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查"))
}
// TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails runs the
// full Forward path against a recorded upstream 400 ("Missing required
// parameter: 'instructions'") and checks that the error surfaces as a 502 to
// the client while the instructions-required warn log captures request
// details without a raw body preview.
func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logSink, restore := captureStructuredLog(t)
	defer restore()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	c.Request.Header.Set("Content-Type", "application/json")
	c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
	// Canned upstream response: the 400 error shape returned by OpenAI when
	// instructions are missing.
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusBadRequest,
			Header: http.Header{
				"Content-Type": []string{"application/json"},
				"x-request-id": []string{"rid-upstream"},
			},
			Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Missing required parameter: 'instructions'","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`)),
		},
	}
	svc := &OpenAIGatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{ForceCodexCLI: false},
		},
		httpUpstream: upstream,
	}
	account := &Account{
		ID:             1001,
		Name:           "codex max套餐",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeAPIKey,
		Concurrency:    1,
		Credentials:    map[string]any{"api_key": "sk-test"},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}],"prompt_cache_key":"pc-forward","access_token":"secret-token"}`)
	_, err := svc.Forward(context.Background(), c, account, body)
	require.Error(t, err)
	// Upstream 400 is surfaced to the client as 502 with a wrapped error.
	require.Equal(t, http.StatusBadGateway, rec.Code)
	require.Contains(t, err.Error(), "upstream error: 400")
	require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn"))
	require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.1.0"))
	require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex"))
	require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta"))
	require.True(t, logSink.ContainsField("request_body_size"))
	require.False(t, logSink.ContainsField("request_body_preview"))
}
package service
import (
"testing"
"time"
)
// TestCodexSnapshotBaseTime verifies the base-time resolution order: a
// parsable RFC3339 UpdatedAt wins; nil snapshots and empty or invalid
// UpdatedAt values fall back to the supplied time.
func TestCodexSnapshotBaseTime(t *testing.T) {
	fallback := time.Date(2026, 2, 20, 9, 0, 0, 0, time.UTC)
	t.Run("nil snapshot uses fallback", func(t *testing.T) {
		got := codexSnapshotBaseTime(nil, fallback)
		if !got.Equal(fallback) {
			t.Fatalf("got %v, want fallback %v", got, fallback)
		}
	})
	t.Run("empty updatedAt uses fallback", func(t *testing.T) {
		got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{}, fallback)
		if !got.Equal(fallback) {
			t.Fatalf("got %v, want fallback %v", got, fallback)
		}
	})
	t.Run("valid updatedAt wins", func(t *testing.T) {
		got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "2026-02-16T10:00:00Z"}, fallback)
		want := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC)
		if !got.Equal(want) {
			t.Fatalf("got %v, want %v", got, want)
		}
	})
	t.Run("invalid updatedAt uses fallback", func(t *testing.T) {
		got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "invalid"}, fallback)
		if !got.Equal(fallback) {
			t.Fatalf("got %v, want fallback %v", got, fallback)
		}
	})
}
func TestCodexResetAtRFC3339(t *testing.T) {
base := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC)
t.Run("nil reset returns nil", func(t *testing.T) {
if got := codexResetAtRFC3339(base, nil); got != nil {
t.Fatalf("expected nil, got %v", *got)
}
})
t.Run("positive seconds", func(t *testing.T) {
sec := 90
got := codexResetAtRFC3339(base, &sec)
if got == nil {
t.Fatal("expected non-nil")
}
if *got != "2026-02-16T10:01:30Z" {
t.Fatalf("got %s, want %s", *got, "2026-02-16T10:01:30Z")
}
})
t.Run("negative seconds clamp to base", func(t *testing.T) {
sec := -3
got := codexResetAtRFC3339(base, &sec)
if got == nil {
t.Fatal("expected non-nil")
}
if *got != "2026-02-16T10:00:00Z" {
t.Fatalf("got %s, want %s", *got, "2026-02-16T10:00:00Z")
}
})
}
// TestBuildCodexUsageExtraUpdates_UsesSnapshotUpdatedAt checks that absolute
// reset timestamps are anchored at the snapshot's own UpdatedAt rather than
// the fallback clock. In this fixture the secondary window (300 min) feeds
// the 5h fields and the primary window (10080 min) feeds the 7d fields.
func TestBuildCodexUsageExtraUpdates_UsesSnapshotUpdatedAt(t *testing.T) {
	primaryUsed := 88.0
	primaryReset := 86400
	primaryWindow := 10080
	secondaryUsed := 12.0
	secondaryReset := 3600
	secondaryWindow := 300
	snapshot := &OpenAICodexUsageSnapshot{
		PrimaryUsedPercent:         &primaryUsed,
		PrimaryResetAfterSeconds:   &primaryReset,
		PrimaryWindowMinutes:       &primaryWindow,
		SecondaryUsedPercent:       &secondaryUsed,
		SecondaryResetAfterSeconds: &secondaryReset,
		SecondaryWindowMinutes:     &secondaryWindow,
		UpdatedAt:                  "2026-02-16T10:00:00Z",
	}
	updates := buildCodexUsageExtraUpdates(snapshot, time.Date(2026, 2, 20, 8, 0, 0, 0, time.UTC))
	if updates == nil {
		t.Fatal("expected non-nil updates")
	}
	if got := updates["codex_usage_updated_at"]; got != "2026-02-16T10:00:00Z" {
		t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-16T10:00:00Z")
	}
	// 5h reset = UpdatedAt + 3600s (secondary reset offset).
	if got := updates["codex_5h_reset_at"]; got != "2026-02-16T11:00:00Z" {
		t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T11:00:00Z")
	}
	// 7d reset = UpdatedAt + 86400s (primary reset offset).
	if got := updates["codex_7d_reset_at"]; got != "2026-02-17T10:00:00Z" {
		t.Fatalf("codex_7d_reset_at = %v, want %s", got, "2026-02-17T10:00:00Z")
	}
}
// TestBuildCodexUsageExtraUpdates_FallbackToNowWhenUpdatedAtInvalid checks
// that an unparsable UpdatedAt makes both the recorded updated-at value and
// the derived reset timestamps use the fallback time instead.
func TestBuildCodexUsageExtraUpdates_FallbackToNowWhenUpdatedAtInvalid(t *testing.T) {
	primaryUsed := 15.0
	primaryReset := 30
	primaryWindow := 300
	fallbackNow := time.Date(2026, 2, 20, 8, 30, 0, 0, time.UTC)
	snapshot := &OpenAICodexUsageSnapshot{
		PrimaryUsedPercent:       &primaryUsed,
		PrimaryResetAfterSeconds: &primaryReset,
		PrimaryWindowMinutes:     &primaryWindow,
		UpdatedAt:                "invalid-time",
	}
	updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow)
	if updates == nil {
		t.Fatal("expected non-nil updates")
	}
	if got := updates["codex_usage_updated_at"]; got != "2026-02-20T08:30:00Z" {
		t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T08:30:00Z")
	}
	// 5h reset = fallbackNow + 30s (primary window is 300 min here).
	if got := updates["codex_5h_reset_at"]; got != "2026-02-20T08:30:30Z" {
		t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-20T08:30:30Z")
	}
}
// TestBuildCodexUsageExtraUpdates_ClampNegativeResetSeconds checks that the
// raw reset_after_seconds value is stored unmodified even when negative,
// while the derived absolute reset timestamp is clamped to the base time.
func TestBuildCodexUsageExtraUpdates_ClampNegativeResetSeconds(t *testing.T) {
	primaryUsed := 90.0
	primaryReset := 7200
	primaryWindow := 10080
	secondaryUsed := 100.0
	secondaryReset := -15
	secondaryWindow := 300
	snapshot := &OpenAICodexUsageSnapshot{
		PrimaryUsedPercent:         &primaryUsed,
		PrimaryResetAfterSeconds:   &primaryReset,
		PrimaryWindowMinutes:       &primaryWindow,
		SecondaryUsedPercent:       &secondaryUsed,
		SecondaryResetAfterSeconds: &secondaryReset,
		SecondaryWindowMinutes:     &secondaryWindow,
		UpdatedAt:                  "2026-02-16T10:00:00Z",
	}
	updates := buildCodexUsageExtraUpdates(snapshot, time.Time{})
	if updates == nil {
		t.Fatal("expected non-nil updates")
	}
	// The raw (negative) seconds value is preserved as-is.
	if got := updates["codex_5h_reset_after_seconds"]; got != -15 {
		t.Fatalf("codex_5h_reset_after_seconds = %v, want %d", got, -15)
	}
	// The derived timestamp is clamped to the snapshot's UpdatedAt.
	if got := updates["codex_5h_reset_at"]; got != "2026-02-16T10:00:00Z" {
		t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T10:00:00Z")
	}
}
// TestBuildCodexUsageExtraUpdates_NilSnapshot asserts a nil snapshot yields
// no updates map at all.
func TestBuildCodexUsageExtraUpdates_NilSnapshot(t *testing.T) {
	updates := buildCodexUsageExtraUpdates(nil, time.Now())
	if updates != nil {
		t.Fatalf("expected nil updates, got %v", updates)
	}
}
// TestBuildCodexUsageExtraUpdates_WithoutNormalizedWindowFields checks that a
// snapshot carrying no reset/window data still records the updated-at
// timestamp (from the fallback, since UpdatedAt is empty) and emits no
// reset_at keys.
func TestBuildCodexUsageExtraUpdates_WithoutNormalizedWindowFields(t *testing.T) {
	primaryUsed := 42.0
	fallbackNow := time.Date(2026, 2, 20, 9, 15, 0, 0, time.UTC)
	snapshot := &OpenAICodexUsageSnapshot{
		PrimaryUsedPercent: &primaryUsed,
		UpdatedAt:          "",
	}
	updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow)
	if updates == nil {
		t.Fatal("expected non-nil updates")
	}
	if got := updates["codex_usage_updated_at"]; got != "2026-02-20T09:15:00Z" {
		t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T09:15:00Z")
	}
	if _, ok := updates["codex_5h_reset_at"]; ok {
		t.Fatalf("did not expect codex_5h_reset_at in updates: %v", updates["codex_5h_reset_at"])
	}
	if _, ok := updates["codex_7d_reset_at"]; ok {
		t.Fatalf("did not expect codex_7d_reset_at in updates: %v", updates["codex_7d_reset_at"])
	}
}
package service
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestExtractOpenAIRequestMetaFromBody covers full, partial and empty request
// bodies; string fields are expected back whitespace-trimmed.
func TestExtractOpenAIRequestMetaFromBody(t *testing.T) {
	tests := []struct {
		name          string
		body          []byte
		wantModel     string
		wantStream    bool
		wantPromptKey string
	}{
		{
			// prompt_cache_key carries padding that must be trimmed.
			name:          "完整字段",
			body:          []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`),
			wantModel:     "gpt-5",
			wantStream:    true,
			wantPromptKey: "ses-1",
		},
		{
			name:          "缺失可选字段",
			body:          []byte(`{"model":"gpt-4"}`),
			wantModel:     "gpt-4",
			wantStream:    false,
			wantPromptKey: "",
		},
		{
			name:          "空请求体",
			body:          nil,
			wantModel:     "",
			wantStream:    false,
			wantPromptKey: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body)
			require.Equal(t, tt.wantModel, model)
			require.Equal(t, tt.wantStream, stream)
			require.Equal(t, tt.wantPromptKey, promptKey)
		})
	}
}
// TestExtractOpenAIReasoningEffortFromBody verifies reasoning-effort
// resolution: reasoning.effort takes precedence, the legacy reasoning_effort
// key is accepted (values normalized, e.g. "x-high" -> "xhigh"), "minimal"
// normalizes to nil, and when the body has no effort field a known model
// suffix (e.g. "-high") is used; unknown suffixes return nil.
func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) {
	tests := []struct {
		name      string
		body      []byte
		model     string
		wantNil   bool
		wantValue string
	}{
		{
			name:      "优先读取 reasoning.effort",
			body:      []byte(`{"reasoning":{"effort":"medium"}}`),
			model:     "gpt-5-high",
			wantNil:   false,
			wantValue: "medium",
		},
		{
			name:      "兼容 reasoning_effort",
			body:      []byte(`{"reasoning_effort":"x-high"}`),
			model:     "",
			wantNil:   false,
			wantValue: "xhigh",
		},
		{
			name:    "minimal 归一化为空",
			body:    []byte(`{"reasoning":{"effort":"minimal"}}`),
			model:   "gpt-5-high",
			wantNil: true,
		},
		{
			name:      "缺失字段时从模型后缀推导",
			body:      []byte(`{"input":"hi"}`),
			model:     "gpt-5-high",
			wantNil:   false,
			wantValue: "high",
		},
		{
			name:    "未知后缀不返回",
			body:    []byte(`{"input":"hi"}`),
			model:   "gpt-5-unknown",
			wantNil: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model)
			if tt.wantNil {
				require.Nil(t, got)
				return
			}
			require.NotNil(t, got)
			require.Equal(t, tt.wantValue, *got)
		})
	}
}
// TestGetOpenAIRequestBodyMap_UsesContextCache verifies that a body map
// already cached on the gin context wins over the raw bytes — even when
// those bytes are not valid JSON.
func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) {
	gin.SetMode(gin.TestMode)
	c, _ := gin.CreateTestContext(httptest.NewRecorder())

	cached := map[string]any{"model": "cached-model", "stream": true}
	c.Set(OpenAIParsedRequestBodyKey, cached)

	got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`))
	require.NoError(t, err)
	require.Equal(t, cached, got)
}
// TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache verifies that malformed
// JSON surfaces a parse error when no context cache is available.
func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
	malformed := []byte(`{invalid-json`)
	_, err := getOpenAIRequestBodyMap(nil, malformed)
	require.Error(t, err)
	require.Contains(t, err.Error(), "parse request")
}
......@@ -14,8 +14,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// Compile-time interface assertions.
var _ AccountRepository = (*stubOpenAIAccountRepo)(nil)
var _ GatewayCache = (*stubGatewayCache)(nil)
type stubOpenAIAccountRepo struct {
AccountRepository
accounts []Account
......@@ -124,17 +129,19 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
svc := &OpenAIGatewayService{}
bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`)
// 1) session_id header wins
c.Request.Header.Set("session_id", "sess-123")
c.Request.Header.Set("conversation_id", "conv-456")
h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
h1 := svc.GenerateSessionHash(c, bodyWithKey)
if h1 == "" {
t.Fatalf("expected non-empty hash")
}
// 2) conversation_id used when session_id absent
c.Request.Header.Del("session_id")
h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
h2 := svc.GenerateSessionHash(c, bodyWithKey)
if h2 == "" {
t.Fatalf("expected non-empty hash")
}
......@@ -144,7 +151,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
// 3) prompt_cache_key used when both headers absent
c.Request.Header.Del("conversation_id")
h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
h3 := svc.GenerateSessionHash(c, bodyWithKey)
if h3 == "" {
t.Fatalf("expected non-empty hash")
}
......@@ -153,7 +160,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
// 4) empty when no signals
h4 := svc.GenerateSessionHash(c, map[string]any{})
h4 := svc.GenerateSessionHash(c, []byte(`{}`))
if h4 != "" {
t.Fatalf("expected empty hash when no signals")
}
......@@ -1066,6 +1073,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
}
}
// TestOpenAIStreamingReuseScannerBufferAndStillWorks streams a single
// response.completed SSE event through handleStreamingResponse and checks
// that usage (input/output/cached tokens) is extracted from it. The writer
// goroutine closes the pipe once the event is written so the stream ends.
func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}
	go func() {
		defer func() { _ = pw.Close() }()
		_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n"))
	}()
	result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
	_ = pr.Close()
	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, result.usage)
	require.Equal(t, 1, result.usage.InputTokens)
	require.Equal(t, 2, result.usage.OutputTokens)
	require.Equal(t, 3, result.usage.CacheReadInputTokens)
}
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
......@@ -1149,3 +1193,332 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
t.Fatalf("expected non-allowlisted host to fail")
}
}
// ==================== P1-08 fix: model-replacement performance optimization tests ====================
// TestReplaceModelInSSELine verifies model-name replacement within a single
// SSE line: top-level "model" and nested "response.model" are replaced when
// they match `from`; non-matching or missing model fields, [DONE], non-data
// lines, invalid JSON and empty lines are left alone; the no-space "data:"
// prefix is normalized to "data: " in the output; and when both a top-level
// and nested model exist, only the top-level one is replaced.
func TestReplaceModelInSSELine(t *testing.T) {
	svc := &OpenAIGatewayService{}
	tests := []struct {
		name     string
		line     string
		from     string
		to       string
		expected string
	}{
		{
			name:     "顶层 model 字段替换",
			line:     `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
			from:     "gpt-4o",
			to:       "my-custom-model",
			expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
		},
		{
			name:     "嵌套 response.model 替换",
			line:     `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
		},
		{
			name:     "model 不匹配时不替换",
			line:     `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
		},
		{
			name:     "无 model 字段时不替换",
			line:     `data: {"id":"chatcmpl-123","choices":[]}`,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `data: {"id":"chatcmpl-123","choices":[]}`,
		},
		{
			name:     "空 data 行",
			line:     `data: `,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `data: `,
		},
		{
			name:     "[DONE] 行",
			line:     `data: [DONE]`,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `data: [DONE]`,
		},
		{
			name:     "非 data: 前缀行",
			line:     `event: message`,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `event: message`,
		},
		{
			name:     "非法 JSON 不替换",
			line:     `data: {invalid json}`,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `data: {invalid json}`,
		},
		{
			name:     "无空格 data: 格式",
			line:     `data:{"id":"x","model":"gpt-4o"}`,
			from:     "gpt-4o",
			to:       "my-model",
			expected: `data: {"id":"x","model":"my-model"}`,
		},
		{
			name:     "model 名含特殊字符",
			line:     `data: {"model":"org/model-v2.1-beta"}`,
			from:     "org/model-v2.1-beta",
			to:       "custom/alias",
			expected: `data: {"model":"custom/alias"}`,
		},
		{
			name:     "空行",
			line:     "",
			from:     "gpt-4o",
			to:       "my-model",
			expected: "",
		},
		{
			name:     "保持其他字段不变",
			line:     `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
			from:     "gpt-4o",
			to:       "alias",
			expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
		},
		{
			name:     "顶层优先于嵌套:同时存在两个 model",
			line:     `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
			from:     "gpt-4o",
			to:       "replaced",
			expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to)
			require.Equal(t, tt.expected, got)
		})
	}
}
// TestReplaceModelInSSEBody verifies model replacement across a multi-line
// SSE body: each matching data line is rewritten, non-matching lines and
// event/[DONE] lines pass through, and an empty body stays empty.
func TestReplaceModelInSSEBody(t *testing.T) {
	svc := &OpenAIGatewayService{}
	tests := []struct {
		name     string
		body     string
		from     string
		to       string
		expected string
	}{
		{
			name:     "多行 SSE body 替换",
			body:     "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
			from:     "gpt-4o",
			to:       "alias",
			expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
		},
		{
			name:     "无需替换的 body",
			body:     "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
			from:     "gpt-4o",
			to:       "alias",
			expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
		},
		{
			name:     "混合 event 和 data 行",
			body:     "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
			from:     "gpt-4o",
			to:       "alias",
			expected: "event: message\ndata: {\"model\":\"alias\"}\n\n",
		},
		{
			name:     "空 body",
			body:     "",
			from:     "gpt-4o",
			to:       "alias",
			expected: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to)
			require.Equal(t, tt.expected, got)
		})
	}
}
// TestReplaceModelInResponseBody verifies model replacement in a non-stream
// JSON response body: the top-level "model" is replaced only when it matches
// `from`; invalid JSON, empty bodies and missing model fields are returned
// unchanged, and the rest of the document structure is preserved.
func TestReplaceModelInResponseBody(t *testing.T) {
	svc := &OpenAIGatewayService{}
	tests := []struct {
		name     string
		body     string
		from     string
		to       string
		expected string
	}{
		{
			name:     "替换顶层 model",
			body:     `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
			from:     "gpt-4o",
			to:       "alias",
			expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
		},
		{
			name:     "model 不匹配不替换",
			body:     `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
			from:     "gpt-4o",
			to:       "alias",
			expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
		},
		{
			name:     "无 model 字段不替换",
			body:     `{"id":"chatcmpl-123","choices":[]}`,
			from:     "gpt-4o",
			to:       "alias",
			expected: `{"id":"chatcmpl-123","choices":[]}`,
		},
		{
			name:     "非法 JSON 返回原值",
			body:     `not json`,
			from:     "gpt-4o",
			to:       "alias",
			expected: `not json`,
		},
		{
			name:     "空 body 返回原值",
			body:     ``,
			from:     "gpt-4o",
			to:       "alias",
			expected: ``,
		},
		{
			name:     "保持嵌套结构不变",
			body:     `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
			from:     "gpt-4o",
			to:       "alias",
			expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to)
			require.Equal(t, tt.expected, string(got))
		})
	}
}
// TestExtractOpenAISSEDataLine verifies payload extraction from an SSE line:
// both "data: " and "data:" prefixes are accepted (payload returned with
// ok=true, empty payload allowed), while non-data lines yield ok=false.
func TestExtractOpenAISSEDataLine(t *testing.T) {
	tests := []struct {
		name     string
		line     string
		wantData string
		wantOK   bool
	}{
		{name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
		{name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
		{name: "纯空数据", line: `data: `, wantData: ``, wantOK: true},
		{name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, ok := extractOpenAISSEDataLine(tt.line)
			require.Equal(t, tt.wantOK, ok)
			require.Equal(t, tt.wantData, got)
		})
	}
}
// TestParseSSEUsage_SelectiveParsing verifies that parseSSEUsage only applies
// usage from "response.completed" events and leaves the accumulator untouched
// for other event types.
func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
	svc := &OpenAIGatewayService{}
	usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7}
	// A non-completed event must not overwrite the existing usage values.
	svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage)
	require.Equal(t, 9, usage.InputTokens)
	require.Equal(t, 8, usage.OutputTokens)
	require.Equal(t, 7, usage.CacheReadInputTokens)
	// A response.completed event must extract and apply the usage payload.
	svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage)
	require.Equal(t, 3, usage.InputTokens)
	require.Equal(t, 5, usage.OutputTokens)
	require.Equal(t, 2, usage.CacheReadInputTokens)
}
// TestExtractCodexFinalResponse_SampleReplay verifies that the final response
// object is extracted from the "response.completed" event of a replayed SSE
// body, including its id and usage fields.
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
	body := strings.Join([]string{
		`event: message`,
		`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
		`data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`,
		`data: [DONE]`,
	}, "\n")
	finalResp, ok := extractCodexFinalResponse(body)
	require.True(t, ok)
	require.Contains(t, string(finalResp), `"id":"resp_1"`)
	require.Contains(t, string(finalResp), `"input_tokens":11`)
}
// TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON verifies that an SSE
// body containing a response.completed event is converted into a single JSON
// response written downstream, with usage extracted from the completed event.
func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
	}
	body := []byte(strings.Join([]string{
		`data: {"type":"response.in_progress","response":{"id":"resp_2"}}`,
		`data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`,
		`data: [DONE]`,
	}, "\n"))
	usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
	require.NoError(t, err)
	require.NotNil(t, usage)
	require.Equal(t, 7, usage.InputTokens)
	require.Equal(t, 9, usage.OutputTokens)
	require.Equal(t, 1, usage.CacheReadInputTokens)
	// The upstream Content-Type header may be passed through as-is; the key
	// point is that the body has been converted to the final JSON response.
	require.NotContains(t, rec.Body.String(), "event:")
	require.Contains(t, rec.Body.String(), `"id":"resp_2"`)
	require.NotContains(t, rec.Body.String(), "data:")
}
// TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody verifies the fallback:
// without a response.completed event the original SSE body (and its
// text/event-stream Content-Type) is passed through, with zero usage.
func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
	}
	body := []byte(strings.Join([]string{
		`data: {"type":"response.in_progress","response":{"id":"resp_3"}}`,
		`data: [DONE]`,
	}, "\n"))
	usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
	require.NoError(t, err)
	require.NotNil(t, usage)
	require.Equal(t, 0, usage.InputTokens)
	require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
	require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
}
package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func f64p(v float64) *float64 { return &v }
// httpUpstreamRecorder is a test double for the upstream HTTP client: it
// records the last request (plus a replayable copy of its body) and returns a
// canned response or error.
type httpUpstreamRecorder struct {
	lastReq  *http.Request  // most recent request passed to Do/DoWithTLS
	lastBody []byte         // fully-read copy of that request's body
	resp     *http.Response // canned response returned on success
	err      error          // when non-nil, returned instead of resp
}
// Do records req and a copy of its body, rewinds the body so the caller can
// still read it, and returns the canned response/error.
func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	u.lastReq = req
	if req != nil && req.Body != nil {
		// Read errors are deliberately ignored; tests assert on lastBody.
		b, _ := io.ReadAll(req.Body)
		u.lastBody = b
		_ = req.Body.Close()
		// Replace the drained body so the request remains usable downstream.
		req.Body = io.NopCloser(bytes.NewReader(b))
	}
	if u.err != nil {
		return nil, u.err
	}
	return u.resp, nil
}
// DoWithTLS delegates to Do; the TLS-fingerprint flag is irrelevant for the
// recorder.
func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	return u.Do(req, proxyURL, accountID, accountConcurrency)
}
var structuredLogCaptureMu sync.Mutex
// inMemoryLogSink collects structured log events in memory for assertions.
// All access to events is guarded by mu.
type inMemoryLogSink struct {
	mu     sync.Mutex
	events []*logger.LogEvent
}
// WriteLogEvent stores a defensive copy of event (including a shallow clone
// of its Fields map) so later mutations by the logger cannot affect recorded
// events. Nil events are ignored.
func (s *inMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
	if event == nil {
		return
	}
	snapshot := *event
	if fields := event.Fields; fields != nil {
		copied := make(map[string]any, len(fields))
		for key, val := range fields {
			copied[key] = val
		}
		snapshot.Fields = copied
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.events = append(s.events, &snapshot)
}
// ContainsMessage reports whether any recorded event's message contains substr.
func (s *inMemoryLogSink) ContainsMessage(substr string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, event := range s.events {
		if event == nil {
			continue
		}
		if strings.Contains(event.Message, substr) {
			return true
		}
	}
	return false
}
// ContainsMessageAtLevel reports whether any recorded event contains substr
// in its message AND was logged at the given level (compared after trimming
// whitespace and lowercasing both sides).
func (s *inMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	want := strings.ToLower(strings.TrimSpace(level))
	for _, event := range s.events {
		if event == nil || !strings.Contains(event.Message, substr) {
			continue
		}
		if strings.ToLower(strings.TrimSpace(event.Level)) == want {
			return true
		}
	}
	return false
}
// ContainsFieldValue reports whether any recorded event has the named field
// and its fmt.Sprint rendering contains substr.
func (s *inMemoryLogSink) ContainsFieldValue(field, substr string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, event := range s.events {
		if event == nil || event.Fields == nil {
			continue
		}
		value, present := event.Fields[field]
		if !present {
			continue
		}
		if strings.Contains(fmt.Sprint(value), substr) {
			return true
		}
	}
	return false
}
// ContainsField reports whether any recorded event carries the named field,
// regardless of its value.
func (s *inMemoryLogSink) ContainsField(field string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, event := range s.events {
		if event == nil || event.Fields == nil {
			continue
		}
		if _, present := event.Fields[field]; present {
			return true
		}
	}
	return false
}
// captureStructuredLog re-initializes the global logger for tests and installs
// an in-memory sink. It acquires structuredLogCaptureMu and holds it until the
// returned restore func runs, serializing tests that capture global log
// output. Callers must invoke the restore func (typically via defer).
func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) {
	t.Helper()
	structuredLogCaptureMu.Lock()
	err := logger.Init(logger.InitOptions{
		Level:       "debug",
		Format:      "json",
		ServiceName: "sub2api",
		Environment: "test",
		Output: logger.OutputOptions{
			ToStdout: true,
			ToFile:   false,
		},
		Sampling: logger.SamplingOptions{Enabled: false},
	})
	require.NoError(t, err)
	sink := &inMemoryLogSink{}
	logger.SetSink(sink)
	return sink, func() {
		// Detach the sink before releasing the capture lock.
		logger.SetSink(nil)
		structuredLogCaptureMu.Unlock()
	}
}
// TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized
// exercises the OAuth passthrough streaming path end to end: the body is
// normalized (store=false, stream=true) while other fields survive, inbound
// credential/hop headers are stripped, OAuth auth headers are injected, and
// downstream SSE keeps the upstream tool name untouched.
func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	c.Request.Header.Set("Authorization", "Bearer inbound-should-not-forward")
	c.Request.Header.Set("Cookie", "secret=1")
	c.Request.Header.Set("X-Api-Key", "sk-inbound")
	c.Request.Header.Set("X-Goog-Api-Key", "goog-inbound")
	c.Request.Header.Set("Accept-Encoding", "gzip")
	c.Request.Header.Set("Proxy-Authorization", "Basic abc")
	c.Request.Header.Set("X-Test", "keep")
	originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
	upstreamSSE := strings.Join([]string{
		`data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader(upstreamSSE)),
	}
	upstream := &httpUpstreamRecorder{resp: resp}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
		openAITokenProvider: &OpenAITokenProvider{ // minimal: will be bypassed by nil cache/service, but GetAccessToken uses provider only if non-nil
			accountRepo: nil,
		},
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	// Use the gateway method that reads token from credentials when provider is nil.
	svc.openAITokenProvider = nil
	result, err := svc.Forward(context.Background(), c, account, originalBody)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.Stream)
	// 1) The passthrough OAuth body keeps key legacy-path behavior: store=false + stream=true.
	require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
	require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
	require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String()))
	// Remaining key fields keep their original values.
	require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
	require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
	// 2) only auth is replaced; inbound auth/cookie are not forwarded
	require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization"))
	require.Equal(t, "codex_cli_rs/0.1.0", upstream.lastReq.Header.Get("User-Agent"))
	require.Empty(t, upstream.lastReq.Header.Get("Cookie"))
	require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key"))
	require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key"))
	require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding"))
	require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization"))
	require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
	// 3) required OAuth headers are present
	require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
	require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
	// 4) downstream SSE keeps tool name (no toolCorrector)
	body := rec.Body.String()
	require.Contains(t, body, "apply_patch")
	require.NotContains(t, body, "\"name\":\"edit\"")
}
// TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream
// verifies local rejection: a Codex model request without instructions gets a
// 403 before any upstream call, and the rejection is logged with the request
// UA and a reject_reason field.
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logSink, restore := captureStructuredLog(t)
	defer restore()
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
	c.Request.Header.Set("Content-Type", "application/json")
	c.Request.Header.Set("OpenAI-Beta", "responses=experimental")
	// A Codex model without instructions must be rejected locally with 403 and never reach upstream.
	originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
			Body:       io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
		},
	}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	result, err := svc.Forward(context.Background(), c, account, originalBody)
	require.Error(t, err)
	require.Nil(t, result)
	require.Equal(t, http.StatusForbidden, rec.Code)
	require.Contains(t, rec.Body.String(), "requires a non-empty instructions field")
	// Upstream must never have been called.
	require.Nil(t, upstream.lastReq)
	require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions"))
	require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
	require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing"))
}
// TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform
// verifies that with openai_passthrough disabled the legacy OAuth transform
// rewrites the body, forcing store=false and stream=true.
func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	// store=true + stream=false should be forced to store=false + stream=true by applyCodexOAuthTransform (OAuth legacy path)
	inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
	}
	upstream := &httpUpstreamRecorder{resp: resp}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": false},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	_, err := svc.Forward(context.Background(), c, account, inputBody)
	require.NoError(t, err)
	// legacy path rewrites request body (not byte-equal)
	require.NotEqual(t, inputBody, upstream.lastBody)
	require.Contains(t, string(upstream.lastBody), `"store":false`)
	require.Contains(t, string(upstream.lastBody), `"stream":true`)
}
// TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator
// verifies that a composite User-Agent containing codex_cli_rs (but not as a
// prefix) is still treated as Codex, so the originator header is codex_cli_rs
// rather than opencode.
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	// Composite UA whose prefix is not codex_cli_rs; the historical implementation
	// misclassified it as non-Codex and fell through to opencode.
	c.Request.Header.Set("User-Agent", "Mozilla/5.0 codex_cli_rs/0.1.0")
	inputBody := []byte(`{"model":"gpt-5.2","stream":true,"store":false,"input":[{"type":"text","text":"hi"}]}`)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
	}
	upstream := &httpUpstreamRecorder{resp: resp}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": false},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	_, err := svc.Forward(context.Background(), c, account, inputBody)
	require.NoError(t, err)
	require.NotNil(t, upstream.lastReq)
	require.Equal(t, "codex_cli_rs", upstream.lastReq.Header.Get("originator"))
	require.NotEqual(t, "opencode", upstream.lastReq.Header.Get("originator"))
}
// TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex
// verifies that upstream x-codex-* usage headers are forwarded to the
// downstream response in passthrough mode.
func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
	headers := make(http.Header)
	headers.Set("Content-Type", "application/json")
	headers.Set("x-request-id", "rid")
	headers.Set("x-codex-primary-used-percent", "12")
	headers.Set("x-codex-secondary-used-percent", "34")
	headers.Set("x-codex-primary-window-minutes", "300")
	headers.Set("x-codex-secondary-window-minutes", "10080")
	headers.Set("x-codex-primary-reset-after-seconds", "1")
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     headers,
		Body:       io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
	}
	upstream := &httpUpstreamRecorder{resp: resp}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	_, err := svc.Forward(context.Background(), c, account, originalBody)
	require.NoError(t, err)
	require.Equal(t, "12", rec.Header().Get("x-codex-primary-used-percent"))
	require.Equal(t, "34", rec.Header().Get("x-codex-secondary-used-percent"))
}
// TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughFlag
// verifies that an upstream 4xx in passthrough mode records an ops upstream
// error event with its Passthrough flag set.
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughFlag(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
	resp := &http.Response{
		StatusCode: http.StatusBadRequest,
		Header:     http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader(`{"error":{"message":"bad"}}`)),
	}
	upstream := &httpUpstreamRecorder{resp: resp}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	_, err := svc.Forward(context.Background(), c, account, originalBody)
	require.Error(t, err)
	// should append an upstream error event with passthrough=true
	v, ok := c.Get(OpsUpstreamErrorsKey)
	require.True(t, ok)
	arr, ok := v.([]*OpsUpstreamErrorEvent)
	require.True(t, ok)
	require.NotEmpty(t, arr)
	require.True(t, arr[len(arr)-1].Passthrough)
}
// TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA
// verifies that a non-Codex client UA is replaced with the default Codex CLI
// UA when forwarding, while body normalization (store=false, stream=true)
// still applies.
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	// Non-Codex UA
	c.Request.Header.Set("User-Agent", "curl/8.0")
	inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
	}
	upstream := &httpUpstreamRecorder{resp: resp}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	_, err := svc.Forward(context.Background(), c, account, inputBody)
	require.NoError(t, err)
	require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
	require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
	require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
}
// TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient verifies that
// an account with codex_cli_only enabled rejects a non-Codex client UA with
// 403 and an explanatory message.
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	c.Request.Header.Set("User-Agent", "curl/8.0")
	inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
	svc := &OpenAIGatewayService{
		cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
	}
	account := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true, "codex_cli_only": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}
	_, err := svc.Forward(context.Background(), c, account, inputBody)
	require.Error(t, err)
	require.Equal(t, http.StatusForbidden, rec.Code)
	require.Contains(t, rec.Body.String(), "Codex official clients")
}
// TestOpenAIGatewayService_CodexCLIOnly_AllowOfficialClientFamilies verifies
// that every official Codex client family passes the codex_cli_only gate,
// whether matched by User-Agent prefix or by the originator header.
func TestOpenAIGatewayService_CodexCLIOnly_AllowOfficialClientFamilies(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cases := []struct {
		name       string
		ua         string
		originator string
	}{
		{name: "codex_cli_rs", ua: "codex_cli_rs/0.99.0", originator: ""},
		{name: "codex_vscode", ua: "codex_vscode/1.0.0", originator: ""},
		{name: "codex_app", ua: "codex_app/2.1.0", originator: ""},
		{name: "originator_codex_chatgpt_desktop", ua: "curl/8.0", originator: "codex_chatgpt_desktop"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			recorder := httptest.NewRecorder()
			ginCtx, _ := gin.CreateTestContext(recorder)
			ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
			ginCtx.Request.Header.Set("User-Agent", tc.ua)
			if tc.originator != "" {
				ginCtx.Request.Header.Set("originator", tc.originator)
			}

			// Fake upstream serving a minimal, immediately-terminated SSE stream.
			upstream := &httpUpstreamRecorder{resp: &http.Response{
				StatusCode: http.StatusOK,
				Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
				Body:       io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
			}}
			svc := &OpenAIGatewayService{
				cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
				httpUpstream: upstream,
			}
			acct := &Account{
				ID:             123,
				Name:           "acc",
				Platform:       PlatformOpenAI,
				Type:           AccountTypeOAuth,
				Concurrency:    1,
				Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
				Extra:          map[string]any{"openai_passthrough": true, "codex_cli_only": true},
				Status:         StatusActive,
				Schedulable:    true,
				RateMultiplier: f64p(1),
			}

			body := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
			_, err := svc.Forward(context.Background(), ginCtx, acct, body)
			require.NoError(t, err)
			// A recorded upstream request proves the client was not rejected.
			require.NotNil(t, upstream.lastReq)
		})
	}
}
// TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs verifies
// that a streaming passthrough request populates the time-to-first-token
// metric (FirstTokenMs) on the forward result.
func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")

	// Upstream emits one delta event, then the [DONE] terminator.
	sse := strings.Join([]string{
		`data: {"type":"response.output_text.delta","delta":"h"}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader(sse)),
	}}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	acct := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	body := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
	start := time.Now()
	result, err := svc.Forward(context.Background(), ginCtx, acct, body)
	require.NoError(t, err)
	// sanity: duration after start
	require.GreaterOrEqual(t, time.Since(start), time.Duration(0))
	require.NotNil(t, result.FirstTokenMs)
	require.GreaterOrEqual(t, *result.FirstTokenMs, 0)
}
// TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage
// verifies that when the downstream client disconnects mid-stream, the service
// keeps draining the upstream SSE stream and still collects token usage.
func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	// First write succeeds, every later write fails — simulates a client
	// that disconnects partway through the stream.
	ginCtx.Writer = &failingGinWriter{ResponseWriter: ginCtx.Writer, failAfter: 1}

	// Upstream stream: one delta, a completed event carrying usage, then [DONE].
	sse := strings.Join([]string{
		`data: {"type":"response.output_text.delta","delta":"h"}`,
		"",
		`data: {"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader(sse)),
	}}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	acct := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	body := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
	result, err := svc.Forward(context.Background(), ginCtx, acct, body)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.Stream)
	require.NotNil(t, result.FirstTokenMs)
	// Usage must match the response.completed event despite the disconnect.
	require.Equal(t, 11, result.Usage.InputTokens)
	require.Equal(t, 7, result.Usage.OutputTokens)
	require.Equal(t, 3, result.Usage.CacheReadInputTokens)
}
// TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEndpoint
// verifies API-key passthrough: the request body is forwarded byte-for-byte to
// the account's /v1/responses endpoint with Bearer auth, the client User-Agent
// survives, and unrelated client headers (X-Test) are stripped.
func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEndpoint(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	ginCtx.Request.Header.Set("User-Agent", "curl/8.0")
	ginCtx.Request.Header.Set("X-Test", "keep")

	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
		Body:       io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
	}}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	acct := &Account{
		ID:             456,
		Name:           "apikey-acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeAPIKey,
		Concurrency:    1,
		Credentials:    map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	body := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`)
	_, err := svc.Forward(context.Background(), ginCtx, acct, body)
	require.NoError(t, err)
	require.NotNil(t, upstream.lastReq)
	require.Equal(t, body, upstream.lastBody)
	require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String())
	require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization"))
	require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent"))
	require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
}
// TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream
// verifies that a streaming request carrying a timeout-related header
// (x-stainless-timeout) triggers a structured warning log that names the
// offending header and its value.
func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logSink, restore := captureStructuredLog(t)
	defer restore()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	ginCtx.Request.Header.Set("x-stainless-timeout", "10000")

	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-timeout"}},
		Body:       io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
	}}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	acct := &Account{
		ID:             321,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	body := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
	_, err := svc.Forward(context.Background(), ginCtx, acct, body)
	require.NoError(t, err)
	require.True(t, logSink.ContainsMessage("检测到超时相关请求头,将按配置过滤以降低断流风险"))
	require.True(t, logSink.ContainsFieldValue("timeout_headers", "x-stainless-timeout=10000"))
}
// TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone
// verifies that an upstream SSE stream ending without the [DONE] terminator
// is logged at info level as a suspected truncation, tagged with the
// upstream request id.
func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logSink, restore := captureStructuredLog(t)
	defer restore()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")

	// Deliberately omit [DONE] to simulate the upstream cutting off mid-stream.
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-truncate"}},
		Body:       io.NopCloser(strings.NewReader("data: {\"type\":\"response.output_text.delta\",\"delta\":\"h\"}\n\n")),
	}}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	acct := &Account{
		ID:             654,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	body := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
	_, err := svc.Forward(context.Background(), ginCtx, acct, body)
	require.NoError(t, err)
	require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流"))
	require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info"))
	require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate"))
}
// TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders
// verifies the default configuration: timeout-related headers such as
// x-stainless-timeout are dropped before the request reaches the upstream,
// as are unrelated client headers (X-Test).
func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	ginCtx.Request.Header.Set("x-stainless-timeout", "120000")
	ginCtx.Request.Header.Set("X-Test", "keep")

	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}},
		Body:       io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
	}}
	svc := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: upstream,
	}
	acct := &Account{
		ID:             111,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	body := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
	_, err := svc.Forward(context.Background(), ginCtx, acct, body)
	require.NoError(t, err)
	require.NotNil(t, upstream.lastReq)
	require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout"))
	require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
}
// TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured
// verifies that enabling OpenAIPassthroughAllowTimeoutHeaders lets
// x-stainless-timeout pass through to the upstream, while unrelated client
// headers (X-Test) are still stripped.
func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	ginCtx.Request.Header.Set("x-stainless-timeout", "120000")
	ginCtx.Request.Header.Set("X-Test", "keep")

	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}},
		Body:       io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
	}}
	svc := &OpenAIGatewayService{
		cfg: &config.Config{Gateway: config.GatewayConfig{
			ForceCodexCLI:                        false,
			OpenAIPassthroughAllowTimeoutHeaders: true,
		}},
		httpUpstream: upstream,
	}
	acct := &Account{
		ID:             222,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	body := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
	_, err := svc.Forward(context.Background(), ginCtx, acct, body)
	require.NoError(t, err)
	require.NotNil(t, upstream.lastReq)
	require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout"))
	require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment