Commit 3b7a5fff authored by 陈曦's avatar 陈曦
Browse files

补充 OpenAI、Gemini 以及流式请求的采集数据以及 NFS 落库

parent 8519a8eb
Pipeline #82284 failed with stage
in 2 minutes and 21 seconds
......@@ -92,6 +92,235 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly
require.Equal(t, "fc1", second["call_id"])
}
// TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID verifies that a
// tool_search_output input item keeps its type and that its call_id is rewritten
// from the "call_" prefix to the "fc" prefix ("call_1" -> "fc1").
func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.2",
		"input": []any{
			map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"},
		},
	}
	applyCodexOAuthTransform(reqBody, false, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	first, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "tool_search_output", first["type"])
	// The "call_" prefix is normalized to "fc"; the suffix is preserved.
	require.Equal(t, "fc1", first["call_id"])
}
// TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID verifies
// that custom_tool_call_output and mcp_tool_call_output items keep their call
// IDs, with the "call_" prefix rewritten to "fc" on each.
func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.2",
		"input": []any{
			map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"},
			map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"},
		},
	}
	applyCodexOAuthTransform(reqBody, false, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 2)
	first, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "fccustom", first["call_id"])
	second, ok := input[1].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "fcmcp", second["call_id"])
}
// TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID verifies
// that image_generation_call keeps its "id" without gaining a "call_id", and
// that an existing "call_id" on web_search_call is stripped.
func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.2",
		"input": []any{
			map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"},
			map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody, false, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 2)
	first, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "ig_123", first["id"])
	_, hasCallID := first["call_id"]
	require.False(t, hasCallID)
	second, ok := input[1].(map[string]any)
	require.True(t, ok)
	// The bogus call_id supplied by the client must be removed.
	_, hasCallID = second["call_id"]
	require.False(t, hasCallID)
}
// TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput
// verifies that a chat-style message with role "tool" is converted into a
// function_call_output item: the tool_call_id becomes a normalized call_id,
// the content becomes output, and the role key is removed.
func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.4",
		"input": []any{
			map[string]any{
				"type":         "message",
				"role":         "tool",
				"tool_call_id": "call_1",
				"content":      "ok",
			},
		},
	}
	applyCodexOAuthTransform(reqBody, true, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "function_call_output", item["type"])
	// "call_1" -> "fc1" via the call-ID prefix normalization.
	require.Equal(t, "fc1", item["call_id"])
	require.Equal(t, "ok", item["output"])
	_, hasRole := item["role"]
	require.False(t, hasRole)
}
// TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText verifies
// that a non-string "text" value inside an input_text content part is
// JSON-stringified in place (["a","b"] -> `["a","b"]`).
func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.4",
		"input": []any{
			map[string]any{
				"type": "message",
				"role": "user",
				"content": []any{
					map[string]any{"type": "input_text", "text": []any{"a", "b"}},
				},
			},
		},
	}
	applyCodexOAuthTransform(reqBody, true, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	content, ok := item["content"].([]any)
	require.True(t, ok)
	part, ok := content[0].(map[string]any)
	require.True(t, ok)
	// The array value must arrive as its compact JSON encoding.
	require.Equal(t, `["a","b"]`, part["text"])
}
// TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice checks that a
// structured tool_choice whose type has no matching declared tool is replaced
// by the plain "auto" string.
func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) {
	declaredTools := []any{
		map[string]any{"type": "function", "name": "shell"},
	}
	reqBody := map[string]any{
		"model":       "gpt-5.4",
		"tools":       declaredTools,
		"tool_choice": map[string]any{"type": "custom"},
	}
	applyCodexOAuthTransform(reqBody, true, false)
	require.Equal(t, "auto", reqBody["tool_choice"])
}
// TestApplyCodexOAuthTransform_PreservesKnownToolChoice checks that a
// structured tool_choice is left intact when a declared tool of the same
// type exists.
func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
	declaredTools := []any{
		map[string]any{"type": "custom", "name": "shell"},
	}
	reqBody := map[string]any{
		"model":       "gpt-5.4",
		"tools":       declaredTools,
		"tool_choice": map[string]any{"type": "custom"},
	}
	applyCodexOAuthTransform(reqBody, true, false)
	choice, ok := reqBody["tool_choice"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "custom", choice["type"])
}
// TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput verifies
// that a function_call input item with no "name" gains the fallback name
// "tool" and has its call_id normalized.
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.4",
		"input": []any{
			map[string]any{"type": "message", "role": "user", "content": "run tool"},
			map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"},
		},
	}
	applyCodexOAuthTransform(reqBody, true, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 2)
	item, ok := input[1].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "function_call", item["type"])
	// Missing name is backfilled with the generic "tool" fallback.
	require.Equal(t, "tool", item["name"])
	require.Equal(t, "fc1", item["call_id"])
}
// TestApplyCodexOAuthTransform_PreservesFunctionCallInputName verifies that an
// explicitly provided tool name survives the transform (exercised here via a
// custom_tool_call item) while the call_id is still normalized.
func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.4",
		"input": []any{
			map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"},
		},
	}
	applyCodexOAuthTransform(reqBody, true, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "shell", item["name"])
	require.Equal(t, "fc1", item["call_id"])
}
// TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName verifies that an
// mcp_tool_call item keeps its type and name while its call_id is rewritten
// from the "call_" prefix to "fc".
func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.4",
		"input": []any{
			map[string]any{
				"type":      "mcp_tool_call",
				"call_id":   "call_abc",
				"name":      "remote_tool",
				"arguments": "{}",
			},
		},
	}
	applyCodexOAuthTransform(reqBody, true, false)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "mcp_tool_call", item["type"])
	require.Equal(t, "remote_tool", item["name"])
	require.Equal(t, "fcabc", item["call_id"])
}
// TestCodexInputItemRequiresNameTypesAllowCallID ensures every tool-call item
// type both requires a name and is recognised as a tool-call item type.
func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) {
	toolCallTypes := []string{"function_call", "custom_tool_call", "mcp_tool_call"}
	for _, typ := range toolCallTypes {
		require.True(t, codexInputItemRequiresName(typ), typ)
		require.True(t, isCodexToolCallItemType(typ), typ)
	}
}
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。
......@@ -261,6 +490,17 @@ func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) {
require.Equal(t, "png", tool["output_format"])
}
func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"input": "draw a cat",
}
modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
require.False(t, modified)
require.NotContains(t, reqBody, "tools")
}
func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
......@@ -306,6 +546,7 @@ func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *
func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"instructions": "existing instructions",
"tools": []any{
map[string]any{"type": "image_generation", "output_format": "png"},
......@@ -325,6 +566,20 @@ func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testin
require.False(t, modified)
}
func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"instructions": "existing instructions",
"tools": []any{
map[string]any{"type": "image_generation", "output_format": "png"},
},
}
modified := applyCodexImageGenerationBridgeInstructions(reqBody)
require.False(t, modified)
require.Equal(t, "existing instructions", reqBody["instructions"])
}
func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) {
reqBody := map[string]any{
"instructions": "existing instructions",
......@@ -338,6 +593,91 @@ func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *te
require.Equal(t, "existing instructions", reqBody["instructions"])
}
// TestValidateCodexSparkInputRejectsInputImage verifies that Responses-style
// input containing an input_image part is rejected for Spark models with an
// error mentioning the lack of image-input support.
func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.3-codex-spark",
		"input": []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{"type": "input_text", "text": "describe"},
					map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="},
				},
			},
		},
	}
	err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
	require.Error(t, err)
	require.Contains(t, err.Error(), "does not support image input")
}
// TestValidateCodexSparkInputRejectsChatImageURL verifies that chat-style
// "messages" containing an image_url part are also rejected for Spark models.
func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.3-codex-spark",
		"messages": []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{"type": "text", "text": "describe"},
					map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}},
				},
			},
		},
	}
	err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
	require.Error(t, err)
}
// TestValidateCodexSparkInputAllowsTextOnly confirms that text-only Responses
// input passes Spark validation without error.
func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) {
	textContent := []any{
		map[string]any{"type": "input_text", "text": "hello"},
	}
	reqBody := map[string]any{
		"model": "gpt-5.3-codex-spark",
		"input": []any{
			map[string]any{"role": "user", "content": textContent},
		},
	}
	require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark"))
}
// TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions verifies
// that for a Spark model the transform appends the image-unsupported notice
// (marker plus guidance to switch models) to existing instructions, without
// adding the image-generation bridge marker, and reports Modified.
func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) {
	reqBody := map[string]any{
		"model":        "gpt-5.3-codex-spark",
		"instructions": "existing instructions",
		"input":        "hello",
	}
	result := applyCodexOAuthTransform(reqBody, true, false)
	require.True(t, result.Modified)
	instructions, ok := reqBody["instructions"].(string)
	require.True(t, ok)
	// Existing instructions must be preserved, with the Spark notice appended.
	require.Contains(t, instructions, "existing instructions")
	require.Contains(t, instructions, codexSparkImageUnsupportedMarker)
	require.Contains(t, instructions, "does not support image generation")
	require.Contains(t, instructions, "switch to a non-Spark Codex model")
	require.NotContains(t, instructions, codexImageGenerationBridgeMarker)
}
// TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark
// confirms the Spark image-unsupported marker is injected only for Spark
// models, never for regular Codex models.
func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) {
	reqBody := map[string]any{
		"model":        "gpt-5.4",
		"instructions": "existing instructions",
		"input":        "hello",
	}
	applyCodexOAuthTransform(reqBody, true, false)
	instructions, ok := reqBody["instructions"].(string)
	require.True(t, ok)
	require.NotContains(t, instructions, codexSparkImageUnsupportedMarker)
}
func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-image-2",
......
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayService_Forward_CompactOnlyModelMappingOverridesOAuthUpstreamModel
// verifies that on the /v1/responses/compact path the account's
// compact_model_mapping rewrites the upstream model while the client-facing
// result still reports the originally requested model.
func TestOpenAIGatewayService_Forward_CompactOnlyModelMappingOverridesOAuthUpstreamModel(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"compact-test","input":"hello"}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")
	// Stub upstream that records the forwarded body and returns a canned reply.
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-map"}},
		Body:       io.NopCloser(strings.NewReader(`{"id":"resp_123","status":"completed","model":"gpt-5.4-openai-compact","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
	}}
	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":          "oauth-token",
			"chatgpt_account_id":    "chatgpt-acc",
			"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
		},
		Status:      StatusActive,
		Schedulable: true,
	}
	result, err := svc.Forward(context.Background(), c, account, body)
	require.NoError(t, err)
	require.NotNil(t, result)
	// Client-facing model stays as requested; upstream sees the mapped model.
	require.Equal(t, "gpt-5.4", result.Model)
	require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
	require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
}
// TestOpenAIGatewayService_Forward_NonCompactRequestIgnoresCompactOnlyModelMapping
// verifies that the compact_model_mapping is ignored on the regular
// /v1/responses path: the upstream model equals the requested model.
func TestOpenAIGatewayService_Forward_NonCompactRequestIgnoresCompactOnlyModelMapping(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"normal-test","input":"hello"}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-normal-map"}},
		Body:       io.NopCloser(strings.NewReader(`{"id":"resp_124","status":"completed","model":"gpt-5.4","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
	}}
	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          2,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":          "oauth-token",
			"chatgpt_account_id":    "chatgpt-acc",
			"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
		},
		Status:      StatusActive,
		Schedulable: true,
	}
	result, err := svc.Forward(context.Background(), c, account, body)
	require.NoError(t, err)
	require.NotNil(t, result)
	// No rewrite: upstream and client-facing models are identical.
	require.Equal(t, "gpt-5.4", result.Model)
	require.Equal(t, "gpt-5.4", result.UpstreamModel)
	require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
}
// TestOpenAIGatewayService_OAuthPassthrough_CompactOnlyModelMappingOverridesUpstreamModel
// verifies that in OAuth passthrough mode (openai_passthrough extra flag, Codex
// CLI user agent) the compact mapping still rewrites the upstream model, while
// the response body written to the client reports the original model.
func TestOpenAIGatewayService_OAuthPassthrough_CompactOnlyModelMappingOverridesUpstreamModel(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil))
	// Codex CLI UA exercises the passthrough path.
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
	c.Request.Header.Set("Content-Type", "application/json")
	originalBody := []byte(`{"model":"gpt-5.4","stream":true,"store":true,"instructions":"compact-pass","input":[{"type":"text","text":"compact me"}]}`)
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-pass-map"}},
		Body:       io.NopCloser(strings.NewReader(`{"id":"cmp_124","model":"gpt-5.4-openai-compact","usage":{"input_tokens":2,"output_tokens":3}}`)),
	}}
	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          3,
		Name:        "openai-oauth-pass",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":          "oauth-token",
			"chatgpt_account_id":    "chatgpt-acc",
			"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
		},
		Extra:       map[string]any{"openai_passthrough": true},
		Status:      StatusActive,
		Schedulable: true,
	}
	result, err := svc.Forward(context.Background(), c, account, originalBody)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "gpt-5.4", result.Model)
	require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
	require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
	// The body relayed to the client must restore the originally requested model.
	require.Equal(t, "gpt-5.4", gjson.GetBytes(rec.Body.Bytes(), "model").String())
}
package service
import (
"net/http"
"strconv"
"strings"
"time"
)
const (
	// AccountTestModeDefault drives the standard /responses connection test.
	AccountTestModeDefault = "default"
	// AccountTestModeCompact drives the /responses/compact compact-probe test.
	AccountTestModeCompact = "compact"
)

// normalizeAccountTestMode maps arbitrary user input to a known test mode.
// Input is trimmed and lowercased; anything other than "compact" falls back
// to the default mode.
func normalizeAccountTestMode(mode string) string {
	if strings.ToLower(strings.TrimSpace(mode)) == AccountTestModeCompact {
		return AccountTestModeCompact
	}
	return AccountTestModeDefault
}
// createOpenAICompactProbePayload builds the minimal /responses/compact
// request body used to probe whether an account supports the compact endpoint.
// The model name is trimmed of surrounding whitespace.
func createOpenAICompactProbePayload(model string) map[string]any {
	probeMessage := map[string]any{
		"type":    "message",
		"role":    "user",
		"content": "Respond with OK.",
	}
	return map[string]any{
		"model":        strings.TrimSpace(model),
		"instructions": "You are a helpful coding assistant.",
		"input":        []any{probeMessage},
	}
}
// shouldMarkOpenAICompactUnsupported reports whether a failed compact probe
// response conclusively indicates the upstream does not support
// /responses/compact, as opposed to a transient or unrelated failure.
func shouldMarkOpenAICompactUnsupported(status int, body []byte) bool {
	// 404/405/501 unambiguously mean the endpoint does not exist.
	if status == http.StatusNotFound || status == http.StatusMethodNotAllowed || status == http.StatusNotImplemented {
		return true
	}
	// Only 400/403/422 diagnostics are inspected further; everything else is
	// treated as inconclusive.
	if status != http.StatusBadRequest && status != http.StatusForbidden && status != http.StatusUnprocessableEntity {
		return false
	}
	lower := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body) + " " + string(body)))
	// Require the message to mention "compact" together with an explicit
	// unsupported/disabled keyword before flipping the flag.
	if !strings.Contains(lower, "compact") {
		return false
	}
	for _, keyword := range []string{
		"unsupported",
		"not support",
		"does not support",
		"not available",
		"disabled",
	} {
		if strings.Contains(lower, keyword) {
			return true
		}
	}
	return false
}
// buildOpenAICompactProbeExtraUpdates converts the outcome of a compact probe
// (HTTP response, body, transport error) into the account Extra fields to
// persist. It always records the check timestamp and last status (nil when no
// response arrived); openai_compact_supported is only written on a 2xx
// (true) or on a failure that conclusively indicates lack of support (false) —
// inconclusive failures leave the flag untouched.
func buildOpenAICompactProbeExtraUpdates(resp *http.Response, body []byte, probeErr error, now time.Time) map[string]any {
	updates := map[string]any{
		"openai_compact_checked_at": now.Format(time.RFC3339),
		// Default to nil so a missing response clears any stale status.
		"openai_compact_last_status": nil,
	}
	if resp != nil {
		updates["openai_compact_last_status"] = resp.StatusCode
	}
	switch {
	case probeErr != nil:
		// Transport-level failure: record the sanitized error, capped at 2 KiB.
		updates["openai_compact_last_error"] = truncateString(sanitizeUpstreamErrorMessage(probeErr.Error()), 2048)
	case resp == nil:
		// No error and no response: fall back to a generic message.
		updates["openai_compact_last_error"] = "compact probe failed"
	default:
		// HTTP response received: derive the error message from the body,
		// falling back to the raw body, then to "HTTP <code>" for non-2xx.
		errMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
		if errMsg == "" && len(body) > 0 {
			errMsg = strings.TrimSpace(string(body))
		}
		if errMsg == "" && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
			errMsg = "HTTP " + strconv.Itoa(resp.StatusCode)
		}
		errMsg = truncateString(sanitizeUpstreamErrorMessage(errMsg), 2048)
		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
			updates["openai_compact_supported"] = true
			updates["openai_compact_last_error"] = ""
		} else {
			// Only mark unsupported when the failure is conclusive.
			if shouldMarkOpenAICompactUnsupported(resp.StatusCode, body) {
				updates["openai_compact_supported"] = false
			}
			updates["openai_compact_last_error"] = errMsg
		}
	}
	return updates
}
// mergeExtraUpdates combines two update maps into a fresh map, with entries
// from more overriding same-key entries in base. Returns nil when both inputs
// are empty so callers can skip a no-op persist.
func mergeExtraUpdates(base map[string]any, more map[string]any) map[string]any {
	if len(base) == 0 && len(more) == 0 {
		return nil
	}
	merged := make(map[string]any, len(base)+len(more))
	for _, src := range []map[string]any{base, more} {
		for key, value := range src {
			merged[key] = value
		}
	}
	return merged
}
// compactProbeSessionID derives a deterministic sticky-session ID for compact
// probe requests, namespaced per account. Non-positive account IDs share the
// generic "probe_compact" ID.
func compactProbeSessionID(accountID int64) string {
	const prefix = "probe_compact"
	if accountID > 0 {
		return prefix + "_" + strconv.FormatInt(accountID, 10)
	}
	return prefix
}
package service
import (
"errors"
"net/http"
"testing"
"time"
)
// TestNormalizeAccountTestMode exercises trimming, case folding, and the
// default fallback of normalizeAccountTestMode.
func TestNormalizeAccountTestMode(t *testing.T) {
	cases := map[string]string{
		"":          AccountTestModeDefault,
		"default":   AccountTestModeDefault,
		" compact ": AccountTestModeCompact,
		"COMPACT":   AccountTestModeCompact,
		"unknown":   AccountTestModeDefault,
	}
	for input, want := range cases {
		if got := normalizeAccountTestMode(input); got != want {
			t.Fatalf("normalizeAccountTestMode(%q) = %q, want %q", input, got, want)
		}
	}
}
// TestBuildOpenAICompactProbeExtraUpdates_SuccessMarksSupported verifies a 2xx
// probe result records supported=true, the status code, an empty last error,
// and the RFC3339 check timestamp.
func TestBuildOpenAICompactProbeExtraUpdates_SuccessMarksSupported(t *testing.T) {
	now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
	updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusOK}, []byte(`{"id":"cmp_1"}`), nil, now)
	if got := updates["openai_compact_supported"]; got != true {
		t.Fatalf("openai_compact_supported = %v, want true", got)
	}
	if got := updates["openai_compact_last_status"]; got != http.StatusOK {
		t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusOK)
	}
	if got := updates["openai_compact_last_error"]; got != "" {
		t.Fatalf("openai_compact_last_error = %v, want empty string", got)
	}
	if got := updates["openai_compact_checked_at"]; got != now.Format(time.RFC3339) {
		t.Fatalf("openai_compact_checked_at = %v, want %s", got, now.Format(time.RFC3339))
	}
}
// TestBuildOpenAICompactProbeExtraUpdates_404MarksUnsupported verifies a 404
// probe response conclusively marks the account as not supporting compact.
func TestBuildOpenAICompactProbeExtraUpdates_404MarksUnsupported(t *testing.T) {
	now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
	body := []byte(`404 page not found`)
	updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusNotFound}, body, nil, now)
	if got := updates["openai_compact_supported"]; got != false {
		t.Fatalf("openai_compact_supported = %v, want false", got)
	}
	if got := updates["openai_compact_last_status"]; got != http.StatusNotFound {
		t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusNotFound)
	}
}
// TestBuildOpenAICompactProbeExtraUpdates_502DoesNotMarkUnsupported verifies a
// gateway error (transient failure) records the status but leaves the
// supported flag untouched.
func TestBuildOpenAICompactProbeExtraUpdates_502DoesNotMarkUnsupported(t *testing.T) {
	now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
	updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadGateway}, []byte(`Upstream request failed`), nil, now)
	if _, exists := updates["openai_compact_supported"]; exists {
		t.Fatalf("did not expect openai_compact_supported for 502 response")
	}
	if got := updates["openai_compact_last_status"]; got != http.StatusBadGateway {
		t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadGateway)
	}
}
// TestBuildOpenAICompactProbeExtraUpdates_RequestErrorDoesNotMarkUnsupported
// verifies a transport error leaves the supported flag untouched, clears the
// last status (key present with nil value), and records the error message.
func TestBuildOpenAICompactProbeExtraUpdates_RequestErrorDoesNotMarkUnsupported(t *testing.T) {
	now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
	updates := buildOpenAICompactProbeExtraUpdates(nil, nil, errors.New("dial tcp timeout"), now)
	if _, exists := updates["openai_compact_supported"]; exists {
		t.Fatalf("did not expect openai_compact_supported for request error")
	}
	if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
		t.Fatalf("openai_compact_last_status = %v, want nil key", got)
	}
	if got := updates["openai_compact_last_error"]; got == "" {
		t.Fatalf("expected openai_compact_last_error to be populated")
	}
}
// TestBuildOpenAICompactProbeExtraUpdates_NoResponseClearsLastStatus verifies
// the nil-response / nil-error case keeps the last-status key set to nil and
// records the generic "compact probe failed" message.
func TestBuildOpenAICompactProbeExtraUpdates_NoResponseClearsLastStatus(t *testing.T) {
	now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
	updates := buildOpenAICompactProbeExtraUpdates(nil, nil, nil, now)
	if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
		t.Fatalf("openai_compact_last_status = %v, want nil key", got)
	}
	if got := updates["openai_compact_last_error"]; got != "compact probe failed" {
		t.Fatalf("openai_compact_last_error = %v, want compact probe failed", got)
	}
}
// TestBuildOpenAICompactProbeExtraUpdates_UnknownModelDoesNotMarkUnsupported
// verifies a 400 caused by an unknown model (no compact-unsupported keywords)
// does not flip the supported flag.
func TestBuildOpenAICompactProbeExtraUpdates_UnknownModelDoesNotMarkUnsupported(t *testing.T) {
	now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
	body := []byte(`{"error":{"message":"unknown model gpt-5.4-openai-compact"}}`)
	updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadRequest}, body, nil, now)
	if _, exists := updates["openai_compact_supported"]; exists {
		t.Fatalf("did not expect openai_compact_supported for unknown-model diagnostics")
	}
	if got := updates["openai_compact_last_status"]; got != http.StatusBadRequest {
		t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadRequest)
	}
}
// TestBuildOpenAICompactProbeExtraUpdates_EmptyFailureBodyFallsBackToHTTPStatus
// verifies a non-2xx response with an empty body records the synthesized
// "HTTP <code>" error message.
func TestBuildOpenAICompactProbeExtraUpdates_EmptyFailureBodyFallsBackToHTTPStatus(t *testing.T) {
	now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
	updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil, nil, now)
	if got := updates["openai_compact_last_status"]; got != http.StatusServiceUnavailable {
		t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusServiceUnavailable)
	}
	if got := updates["openai_compact_last_error"]; got != "HTTP 503" {
		t.Fatalf("openai_compact_last_error = %v, want HTTP 503", got)
	}
}
......@@ -406,7 +406,14 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
var responseBody string
if respBytes, err := json.Marshal(chatResp); err == nil {
responseBody = string(respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, chatResp)
}
return &OpenAIForwardResult{
RequestID: requestID,
......@@ -416,6 +423,7 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
ResponseBody: responseBody,
}, nil
}
......@@ -448,6 +456,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
var textBuilder strings.Builder // 收集 assistant 文本用于响应捕获
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -466,6 +475,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ResponseBody: textBuilder.String(),
}
}
......@@ -499,6 +509,10 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks {
// 收集 assistant text 用于响应捕获
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != nil {
textBuilder.WriteString(*chunk.Choices[0].Delta.Content)
}
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
......
......@@ -354,7 +354,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
var responseBody string
if respBytes, err := json.Marshal(anthropicResp); err == nil {
responseBody = string(respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, anthropicResp)
}
return &OpenAIForwardResult{
RequestID: requestID,
......@@ -364,6 +370,7 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
ResponseBody: responseBody,
}, nil
}
......@@ -396,6 +403,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
var textBuilder strings.Builder // 收集 assistant text 用于响应捕获
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -415,6 +423,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ResponseBody: textBuilder.String(),
}
}
......@@ -451,6 +460,10 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
// Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
for _, evt := range events {
// 采集 text_delta 用于响应捕获
if evt.Type == "content_block_delta" && evt.Delta != nil && evt.Delta.Type == "text_delta" {
textBuilder.WriteString(evt.Delta.Text)
}
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
logger.L().Warn("openai messages stream: failed to marshal event",
......
......@@ -40,7 +40,7 @@ const (
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
codexCLIUserAgent = "codex_cli_rs/0.104.0"
codexCLIUserAgent = "codex_cli_rs/0.125.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
codexCLIOnlyHeaderValueMaxBytes = 256
......@@ -54,7 +54,7 @@ const (
openAIWSRetryBackoffMaxDefault = 2 * time.Second
openAIWSRetryJitterRatioDefault = 0.2
openAICompactSessionSeedKey = "openai_compact_session_seed"
codexCLIVersion = "0.104.0"
codexCLIVersion = "0.125.0"
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
openAICodexSnapshotPersistMinInterval = 30 * time.Second
)
......@@ -235,6 +235,9 @@ type OpenAIForwardResult struct {
FirstTokenMs *int
ImageCount int
ImageSize string
// ResponseBody 响应内容:非 streaming 为完整 JSON,streaming 为拼接的 assistant text。
// 仅当 API Key 开启了 capture_requests 时才会被使用。
ResponseBody string
}
type OpenAIWSRetryMetricsSnapshot struct {
......@@ -306,6 +309,10 @@ func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
// ErrNoAvailableCompactAccounts indicates the request needs /responses/compact
// support but no compatible account is available.
var ErrNoAvailableCompactAccounts = errors.New("no available OpenAI accounts support /responses/compact")
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
......@@ -442,11 +449,11 @@ func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Contex
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string, requireCompact bool) bool {
if s.channelService == nil {
return false
}
upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
upstreamModel := resolveOpenAIAccountUpstreamModelForRequest(account, requestedModel, requireCompact)
if upstreamModel == "" {
return false
}
......@@ -1121,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
return sessionID
}
// explicitOpenAISessionID extracts an explicitly supplied session identifier
// from the request: the session_id header first, then the conversation_id
// header, then the prompt_cache_key body field. Returns "" when none exists.
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
	if c == nil {
		return ""
	}
	for _, header := range []string{"session_id", "conversation_id"} {
		if id := strings.TrimSpace(c.GetHeader(header)); id != "" {
			return id
		}
	}
	if len(body) == 0 {
		return ""
	}
	return strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
// client session signals. It intentionally skips content-derived fallback and is
// used by stateless endpoints such as /v1/images.
//
// Returns "" when the request carries no explicit session signal. As a side
// effect, the legacy hash variant is attached to the gin context so older
// sticky mappings can still be matched.
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
	sessionID := explicitOpenAISessionID(c, body)
	if sessionID == "" {
		return ""
	}
	// deriveOpenAISessionHashes yields the current hash plus a legacy variant.
	currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
	attachOpenAILegacySessionHashToGin(c, legacyHash)
	return currentHash
}
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
......@@ -1133,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body)
}
......@@ -1208,10 +1238,94 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0)
return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0)
}
// noAvailableOpenAISelectionError builds the standard "no account available"
// error while preserving the compact-specific sentinel when applicable.
func noAvailableOpenAISelectionError(requestedModel string, compactBlocked bool) error {
	switch {
	case compactBlocked:
		return ErrNoAvailableCompactAccounts
	case requestedModel != "":
		return fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
	default:
		return errors.New("no available OpenAI accounts")
	}
}
// openAICompactSupportTier classifies an OpenAI account by compact capability:
// 0 = explicitly unsupported (or not an OpenAI account), 1 = unknown / not yet
// probed, 2 = explicitly supported.
func openAICompactSupportTier(account *Account) int {
	if account == nil || !account.IsOpenAI() {
		return 0
	}
	supported, known := account.OpenAICompactSupportKnown()
	switch {
	case !known:
		return 1
	case supported:
		return 2
	default:
		return 0
	}
}
// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI /
// model / compact-support checks used during account selection. An account is
// eligible when it is schedulable, is an OpenAI account, supports the
// requested model (if any), and — when requireCompact is set — is not
// explicitly known to lack compact support (tier 0).
func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool {
	if account == nil {
		return false
	}
	if !account.IsSchedulable() || !account.IsOpenAI() {
		return false
	}
	if requestedModel != "" && !account.IsModelSupported(requestedModel) {
		return false
	}
	if !requireCompact {
		return true
	}
	// Tier 0 means the account is explicitly known not to support compact;
	// unknown (1) and supported (2) accounts remain eligible.
	return openAICompactSupportTier(account) != 0
}
// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with
// known compact support come first, followed by unknown, then explicitly
// unsupported. This is a stable partition: relative order within each tier is
// preserved. Returns nil for empty input.
func prioritizeOpenAICompactAccounts(accounts []*Account) []*Account {
	if len(accounts) == 0 {
		return nil
	}
	// Bucket index equals the compact tier (0, 1 or 2) reported by
	// openAICompactSupportTier, so appends keep each tier's original order.
	var buckets [3][]*Account
	for _, account := range accounts {
		tier := openAICompactSupportTier(account)
		buckets[tier] = append(buckets[tier], account)
	}
	ordered := make([]*Account, 0, len(accounts))
	ordered = append(ordered, buckets[2]...)
	ordered = append(ordered, buckets[1]...)
	ordered = append(ordered, buckets[0]...)
	return ordered
}
// resolveOpenAIAccountUpstreamModelForRequest resolves the upstream model that
// would be sent for a given request. When the caller is on the
// /responses/compact path (requireCompact), compact-only mappings are applied
// on top of the regular forward-model resolution. Returns "" when no upstream
// model can be resolved for the account.
func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedModel string, requireCompact bool) string {
	forwarded := resolveOpenAIForwardModel(account, requestedModel, "")
	switch {
	case forwarded == "":
		return ""
	case requireCompact:
		return resolveOpenAICompactForwardModel(account, forwarded)
	default:
		return forwarded
	}
}
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
......@@ -1221,7 +1335,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 1. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil {
return account, nil
}
......@@ -1234,13 +1348,10 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact)
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available OpenAI accounts")
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
}
// 4. 设置粘性会话绑定
......@@ -1257,7 +1368,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account {
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account {
if sessionHash == "" {
return nil
}
......@@ -1289,19 +1400,16 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !account.IsSchedulable() || !account.IsOpenAI() {
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
return nil
}
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
......@@ -1316,9 +1424,13 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 返回 nil 表示无可用账号。
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
// Returns nil if no available account. The second return reports whether at
// least one candidate was filtered out solely because it lacks compact support
// (only meaningful when requireCompact=true).
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) {
var selected *Account
selectedCompactTier := -1
compactBlocked := false
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
for i := range accounts {
......@@ -1330,31 +1442,50 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *i
continue
}
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
compactTier := 0
if requireCompact {
compactTier = openAICompactSupportTier(fresh)
if compactTier == 0 {
compactBlocked = true
continue
}
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
if selected == nil {
selected = fresh
selectedCompactTier = compactTier
continue
}
// compact 模式下高 tier 优先;同 tier 内才比较 priority/LRU。
if requireCompact && compactTier != selectedCompactTier {
if compactTier > selectedCompactTier {
selected = fresh
selectedCompactTier = compactTier
}
continue
}
if s.isBetterAccount(fresh, selected) {
selected = fresh
selectedCompactTier = compactTier
}
}
return selected
return selected, compactBlocked
}
// isBetterAccount 判断 candidate 是否比 current 更优。
......@@ -1392,6 +1523,10 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and a
// wait plan. It is the public, non-compact entry point: it delegates to the
// internal load-aware selector with requireCompact disabled.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
	const requireCompact = false
	result, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact)
	return result, err
}
func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
......@@ -1408,7 +1543,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID)
account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID)
if err != nil {
return nil, err
}
......@@ -1461,12 +1596,11 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if clearSticky {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
}
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
......@@ -1491,6 +1625,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
// ============ Layer 2: Load-aware selection ============
baseCandidateCount := 0
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
......@@ -1506,9 +1641,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) {
continue
}
baseCandidateCount++
candidates = append(candidates, acc)
}
......@@ -1528,12 +1664,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if err != nil {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, false)
if requireCompact {
ordered = prioritizeOpenAICompactAccounts(ordered)
}
for _, acc := range ordered {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
......@@ -1581,12 +1724,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
})
shuffleWithinSortGroups(available)
selectionOrder := make([]accountWithLoad, 0, len(available))
if requireCompact {
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
for _, item := range available {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel)
if openAICompactSupportTier(item.account) == tier {
out = append(out, item)
}
}
return out
}
selectionOrder = appendTier(selectionOrder, 2)
selectionOrder = appendTier(selectionOrder, 1)
// tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
// 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
selectionOrder = appendTier(selectionOrder, 0)
} else {
selectionOrder = append(selectionOrder, available...)
}
for _, item := range selectionOrder {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
......@@ -1602,12 +1768,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
// ============ Layer 3: Fallback wait ============
sortAccountsByPriorityAndLastUsed(candidates, false)
if requireCompact {
candidates = prioritizeOpenAICompactAccounts(candidates)
}
for _, acc := range candidates {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{
......@@ -1618,6 +1791,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
})
}
if requireCompact && baseCandidateCount > 0 {
return nil, ErrNoAvailableCompactAccounts
}
return nil, ErrNoAvailableAccounts
}
......@@ -1648,7 +1824,7 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account {
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
if account == nil {
return nil
}
......@@ -1662,20 +1838,20 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
fresh = current
}
if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
return nil
}
if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
return nil
}
return fresh
}
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
if account == nil {
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) {
return nil
}
return account
}
......@@ -1683,10 +1859,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if err != nil || latest == nil {
return nil
}
if !latest.IsSchedulable() || !latest.IsOpenAI() {
return nil
}
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
return nil
}
return latest
......@@ -1995,11 +2168,39 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
account.Type,
)
}
if err := validateCodexSparkInput(reqBody, upstreamModel); err != nil {
setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "")
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": err.Error(),
"param": "input",
},
})
return nil, err
}
// Compact-only model 映射:仅在 /responses/compact 路径生效,且优先级高于
// OAuth 模型规范化(避免 OAuth 规范化覆盖 compact-only 自定义模型)。
isCompactRequest := isOpenAIResponsesCompactPath(c)
compactMapped := false
if isCompactRequest {
compactMappedModel := resolveOpenAICompactForwardModel(account, billingModel)
if compactMappedModel != "" && compactMappedModel != billingModel {
compactMapped = true
upstreamModel = compactMappedModel
reqBody["model"] = compactMappedModel
bodyModified = true
markPatchSet("model", compactMappedModel)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Compact model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", billingModel, compactMappedModel, account.Name, isCodexCLI)
}
}
// OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为
// 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名,
// 以兼容自定义 base_url 的 OpenAI-compatible 上游。
if model, ok := reqBody["model"].(string); ok {
if !compactMapped {
upstreamModel = normalizeOpenAIModelForUpstream(account, model)
if upstreamModel != "" && upstreamModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
......@@ -2008,6 +2209,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true
markPatchSet("model", upstreamModel)
}
}
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
// 确保高版本模型向低版本模型映射不报错
......@@ -2029,7 +2231,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
if account.Type == AccountTypeOAuth {
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c))
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest)
if codexResult.Modified {
bodyModified = true
disablePatch()
......@@ -2504,6 +2706,19 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
reqStream bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
upstreamPassthroughModel := ""
if isOpenAIResponsesCompactPath(c) {
compactMappedModel := resolveOpenAICompactForwardModel(account, reqModel)
if compactMappedModel != "" && compactMappedModel != reqModel {
nextBody, setErr := sjson.SetBytes(body, "model", compactMappedModel)
if setErr != nil {
return nil, fmt.Errorf("set compact passthrough model: %w", setErr)
}
body = nextBody
upstreamPassthroughModel = compactMappedModel
}
}
if account != nil && account.Type == AccountTypeOAuth {
if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
......@@ -2629,14 +2844,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime)
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
if err != nil {
return nil, err
}
usage = result.usage
firstTokenMs = result.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c)
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
if err != nil {
return nil, err
}
......@@ -2654,6 +2869,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: reqModel,
UpstreamModel: upstreamPassthroughModel,
ServiceTier: extractOpenAIServiceTierFromBody(body),
ReasoningEffort: reasoningEffort,
Stream: reqStream,
......@@ -2957,12 +3173,121 @@ type openaiStreamingResultPassthrough struct {
firstTokenMs *int
}
// openAIStreamClientOutputStarted reports whether any bytes have been emitted
// to the downstream client: either the stream loop tracked a local write
// (localStarted), or the gin response writer has already been written to.
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
	switch {
	case localStarted:
		return true
	case c == nil || c.Writer == nil:
		return false
	default:
		return c.Writer.Written()
	}
}
// openAIStreamEventIsPreamble reports whether the SSE event type (after
// trimming surrounding whitespace) is one of the preamble events that precede
// real client-visible output: "response.created" or "response.in_progress".
func openAIStreamEventIsPreamble(eventType string) bool {
	normalized := strings.TrimSpace(eventType)
	return normalized == "response.created" || normalized == "response.in_progress"
}
// openAIStreamDataStartsClientOutput reports whether an SSE data payload marks
// the point where downstream client output begins. Empty payloads,
// "response.failed" events and preamble events do not count — they are held
// back so a pre-output failover can still be attempted.
func openAIStreamDataStartsClientOutput(data, eventType string) bool {
	payload := strings.TrimSpace(data)
	kind := strings.TrimSpace(eventType)
	if payload == "" || kind == "response.failed" {
		return false
	}
	return !openAIStreamEventIsPreamble(eventType)
}
// openAIStreamFailedEventShouldFailover decides whether a response.failed SSE
// event is worth retrying on another account. Errors that look client-caused
// or policy-related (invalid request, content policy, safety blocks) are not
// retried; everything else — including events carrying no diagnostic text at
// all — is considered retryable.
func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
	// Pull the first non-empty value among the candidate JSON paths,
	// normalized to trimmed lowercase for substring matching below.
	firstNonEmptyLower := func(paths ...string) string {
		for _, path := range paths {
			if v := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, path).String())); v != "" {
				return v
			}
		}
		return ""
	}
	code := firstNonEmptyLower("response.error.code", "error.code")
	errType := firstNonEmptyLower("response.error.type", "error.type")
	combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType))
	if combined == "" {
		// No diagnostics at all: assume transient and allow failover.
		return true
	}
	for _, marker := range []string{
		"invalid_request",
		"content_policy",
		"policy",
		"safety",
		"high-risk cyber",
		"not allowed",
		"violat",
	} {
		if strings.Contains(combined, marker) {
			return false
		}
	}
	return true
}
// newOpenAIStreamFailoverError builds an UpstreamFailoverError (HTTP 502) for
// a stream that broke before any client-visible output was written, allowing
// the request to be retried on another account. When a gin context is
// available it also records the failure for ops observability.
func (s *OpenAIGatewayService) newOpenAIStreamFailoverError(
	c *gin.Context,
	account *Account,
	passthrough bool,
	upstreamRequestID string,
	payload []byte,
	message string,
) *UpstreamFailoverError {
	msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
	if msg == "" {
		msg = "OpenAI stream disconnected before completion"
	}

	// Only attach the raw upstream payload as detail when the operator has
	// opted in, truncated to the configured byte budget (default 2048).
	var detail string
	if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		limit := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if limit <= 0 {
			limit = 2048
		}
		detail = truncateString(string(payload), limit)
	}

	// Client-facing JSON body; a gin.H of plain strings cannot fail to
	// marshal, so the error is intentionally discarded.
	respBody, _ := json.Marshal(gin.H{
		"error": gin.H{
			"type":    "upstream_error",
			"message": msg,
		},
	})

	if c != nil {
		setOpsUpstreamError(c, http.StatusBadGateway, msg, detail)
		event := OpsUpstreamErrorEvent{
			Platform:           PlatformOpenAI,
			UpstreamStatusCode: http.StatusBadGateway,
			UpstreamRequestID:  strings.TrimSpace(upstreamRequestID),
			Passthrough:        passthrough,
			Kind:               "failover",
			Message:            msg,
			Detail:             detail,
		}
		// Prefer the concrete account's identity when one is known.
		if account != nil {
			event.Platform = account.Platform
			event.AccountID = account.ID
			event.AccountName = account.Name
		}
		appendOpsUpstreamError(c, event)
	}

	return &UpstreamFailoverError{
		StatusCode:   http.StatusBadGateway,
		ResponseBody: respBody,
	}
}
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
originalModel string,
mappedModel string,
) (*openaiStreamingResultPassthrough, error) {
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
......@@ -2986,7 +3311,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
clientDisconnected := false
sawDone := false
sawTerminalEvent := false
sawFailedEvent := false
failedMessage := ""
clientOutputStarted := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
pendingLines := make([]string, 0, 8)
writePendingLines := func() bool {
for _, pending := range pendingLines {
if _, err := fmt.Fprintln(w, pending); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
return false
}
}
pendingLines = pendingLines[:0]
return true
}
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -2997,18 +3337,40 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
scanner.Buffer(scanBuf[:0], maxLineSize)
defer putSSEScannerBuf64K(scanBuf)
needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
for scanner.Scan() {
line := scanner.Text()
lineStartsClientOutput := false
forceFlushFailedEvent := false
if data, ok := extractOpenAISSEDataLine(line); ok {
dataBytes := []byte(data)
trimmedData := strings.TrimSpace(data)
if needModelReplace && strings.Contains(data, mappedModel) {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
if replacedData, replaced := extractOpenAISSEDataLine(line); replaced {
dataBytes = []byte(replacedData)
trimmedData = strings.TrimSpace(replacedData)
}
}
eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String())
if eventType == "response.failed" {
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
}
forceFlushFailedEvent = true
sawFailedEvent = true
}
if trimmedData == "[DONE]" {
sawDone = true
}
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
......@@ -3016,20 +3378,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
if !clientDisconnected {
if !clientOutputStarted && !lineStartsClientOutput {
pendingLines = append(pendingLines, line)
continue
}
if !clientOutputStarted && len(pendingLines) > 0 {
if !writePendingLines() {
continue
}
}
if _, err := fmt.Fprintln(w, line); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else {
clientOutputStarted = true
flusher.Flush()
}
}
}
if err := scanner.Err(); err != nil {
if sawTerminalEvent {
if sawTerminalEvent && !sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
......@@ -3038,6 +3410,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
}
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
msg := "OpenAI stream disconnected before completion"
if errText := strings.TrimSpace(err.Error()); errText != "" {
msg += ": " + errText
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
account.ID,
......@@ -3046,12 +3429,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
}
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
}
......@@ -3062,6 +3452,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
) (*OpenAIUsage, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
......@@ -3073,7 +3465,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
// stream=false was requested. Without this conversion the client would
// receive raw SSE text or a terminal event with empty output.
if isEventStreamResponse(resp.Header) {
return s.handlePassthroughSSEToJSON(resp, c, body)
return s.handlePassthroughSSEToJSON(resp, c, body, originalModel, mappedModel)
}
usage := &OpenAIUsage{}
......@@ -3095,14 +3487,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
if contentType == "" {
contentType = "application/json"
}
if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
// handlePassthroughSSEToJSON converts an SSE response body into a JSON
// response for the passthrough path. It mirrors handleSSEToJSON but skips
// model replacement (passthrough does not remap models).
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte) (*OpenAIUsage, error) {
// response for the passthrough path. It mirrors handleSSEToJSON while
// preserving passthrough payloads, except compact-only model remapping may
// rewrite model fields back to the original requested model.
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
......@@ -3121,6 +3517,9 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
}
}
body = finalResponse
if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body)
} else {
......@@ -3133,6 +3532,10 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
}
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
}
body = []byte(bodyText)
}
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
......@@ -3631,8 +4034,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now()
// Track downstream writes separately from upstream reads: pre-output failover
// can buffer response.created / response.in_progress, so keepalive must be
// based on downstream idle time.
lastDownstreamWriteAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
......@@ -3640,6 +4045,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
sawFailedEvent := false
failedMessage := ""
clientOutputStarted := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
var streamFailoverErr error
sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected {
return
......@@ -3656,7 +4066,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
if err := flushBuffered(); err != nil {
clientDisconnected = true
return
}
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
needModelReplace := originalModel != mappedModel
......@@ -3664,45 +4077,73 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
}
finalizeStream := func() (*openaiStreamingResult, error) {
if !sawTerminalEvent {
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
return resultWithUsage(), s.newOpenAIStreamFailoverError(
c,
account,
false,
upstreamRequestID,
nil,
"OpenAI stream ended before a terminal event",
)
}
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
if sawFailedEvent {
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
}
if !clientDisconnected {
hadBufferedData := bufferedWriter.Buffered() > 0
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
} else if hadBufferedData {
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
}
if !sawTerminalEvent {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
return resultWithUsage(), nil
}
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil {
return nil, nil, false
}
if sawTerminalEvent {
if sawTerminalEvent && !sawFailedEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
if sawFailedEvent {
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true
}
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
}
if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
sendErrorEvent("response_too_large")
return resultWithUsage(), scanErr, true
}
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
msg := "OpenAI stream disconnected before completion"
if errText := strings.TrimSpace(scanErr.Error()); errText != "" {
msg += ": " + errText
}
return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
}
sendErrorEvent("stream_read_error")
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
}
processSSELine := func(line string, queueDrained bool) {
lastDataAt = time.Now()
if streamFailoverErr != nil {
return
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if data, ok := extractOpenAISSEDataLine(line); ok {
......@@ -3716,18 +4157,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
forceFlushFailedEvent := false
if eventType == "response.failed" {
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
sawFailedEvent = true
streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage)
return
}
forceFlushFailedEvent = true
sawFailedEvent = true
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
dataBytes = correctedData
data = string(correctedData)
line = "data: " + data
eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
}
startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType)
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
shouldFlush := queueDrained
if firstTokenMs == nil && data != "" && data != "[DONE]" {
shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput)
if firstTokenMs == nil && startsClientOutput {
// 保证首个 token 事件尽快出站,避免影响 TTFT。
shouldFlush = true
}
......@@ -3741,12 +4196,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
}
}
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
if firstTokenMs == nil && startsClientOutput {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
......@@ -3762,10 +4220,13 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
} else if queueDrained {
} else if queueDrained && clientOutputStarted {
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
}
}
......@@ -3776,6 +4237,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
defer putSSEScannerBuf64K(scanBuf)
for scanner.Scan() {
processSSELine(scanner.Text(), true)
if streamFailoverErr != nil {
return resultWithUsage(), streamFailoverErr
}
}
if result, err, done := handleScanErr(scanner.Err()); done {
return result, err
......@@ -3825,6 +4289,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return result, err
}
processSSELine(ev.line, len(events) == 0)
if streamFailoverErr != nil {
return resultWithUsage(), streamFailoverErr
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
......@@ -3846,7 +4313,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
......@@ -3857,6 +4324,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
} else {
lastDownstreamWriteAt = time.Now()
}
}
}
......@@ -3935,7 +4404,8 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
return
}
eventType := gjson.GetBytes(data, "type").String()
if eventType != "response.completed" && eventType != "response.done" {
if eventType != "response.completed" && eventType != "response.done" &&
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
return
}
......@@ -4082,7 +4552,7 @@ func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
}
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
switch eventType {
case "response.completed", "response.done", "response.failed":
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return eventType, []byte(data), true
}
}
......
......@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
// Read simulates a canceled upstream read: it consumes nothing and always
// reports context.Canceled.
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }

// Close is a no-op so cancelReadCloser satisfies io.ReadCloser.
func (c cancelReadCloser) Close() error { return nil }
type errReadCloser struct {
err error
}
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
func (r errReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
......@@ -220,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
// TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback
// verifies that explicit session hashing only derives a hash from explicit
// signals (prompt_cache_key in the body, session_id header) and never falls
// back to hashing request content such as an image prompt.
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &OpenAIGatewayService{}

	// newImageCtx builds a fresh gin context for an image-generation request.
	newImageCtx := func() (*gin.Context, *httptest.ResponseRecorder) {
		rec := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(rec)
		c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
		return c, rec
	}

	t.Run("stateless image body stays unstuck", func(t *testing.T) {
		c, _ := newImageCtx()
		body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
		// No explicit session signal: no hash, and nothing attached to the context.
		require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
		require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
	})

	t.Run("prompt_cache_key is explicit", func(t *testing.T) {
		c, _ := newImageCtx()
		got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
		require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
		require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
	})

	t.Run("header overrides body", func(t *testing.T) {
		c, _ := newImageCtx()
		c.Request.Header.Set("session_id", "header-session")
		got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
		require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
	})
}
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
......@@ -1003,6 +1045,190 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
}
}
// TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover: an upstream read
// failure that happens before any byte reaches the client must surface as an
// UpstreamFailoverError (HTTP 502) so the request can be retried on another
// account, with nothing written downstream.
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}}
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	// The body fails on the very first read, before any output is produced.
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       errReadCloser{err: io.ErrUnexpectedEOF},
		Header:     http.Header{"X-Request-Id": []string{"rid-disconnect"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}
	_, err := svc.handleStreamingResponse(c.Request.Context(), upstream, c, account, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	// Nothing may reach the client before failover.
	require.False(t, c.Writer.Written())
	require.Empty(t, recorder.Body.String())
}
// TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover: when upstream
// emits response.failed before any client-visible output has been flushed, the
// handler must return an UpstreamFailoverError (HTTP 502) that carries the
// upstream error message, and must write nothing to the downstream client so
// the request can be retried on another account.
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	// Preamble events only (created / in_progress), then a terminal failure.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.in_progress",
			`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
			"",
			"event: response.failed",
			`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
	}
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	// Failover errors map to 502 and preserve the upstream failure body.
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
	// Nothing was flushed to the client before the failover decision.
	require.False(t, c.Writer.Written())
	require.Empty(t, rec.Body.String())
}
// TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover: a stream that
// carries only preamble events (response.created / response.in_progress) and
// then ends without a terminal event is eligible for failover — the preamble
// was buffered and never flushed, so the client saw nothing and a retry on
// another account is safe.
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	// Upstream closes after the preamble; no terminal event ever arrives.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.in_progress",
			`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
	}
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	// The buffered preamble must not have leaked to the client.
	require.False(t, c.Writer.Written())
	require.Empty(t, rec.Body.String())
}
// TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle: while pre-output
// failover buffers preamble events (nothing has been flushed downstream yet),
// keepalive comments must still be emitted based on downstream idle time — a
// slow drip of response.in_progress events from upstream must not suppress
// them. The stream then completes normally and the final event is delivered.
func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   1, // 1s keepalive so the test can observe ":" comments
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}
	go func() {
		defer func() { _ = pw.Close() }()
		_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
		// ~1.5s of in_progress events: upstream stays active while downstream idles.
		for i := 0; i < 6; i++ {
			time.Sleep(250 * time.Millisecond)
			_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
		}
		_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
	}()
	result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	_ = pr.Close()
	require.NoError(t, err)
	require.NotNil(t, result)
	// At least one SSE keepalive comment and the completed event reached the client.
	require.Contains(t, rec.Body.String(), ":\n\n")
	require.Contains(t, rec.Body.String(), "response.completed")
}
// TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough: a
// response.failed event that carries a policy/safety refusal (safety_error)
// must NOT trigger account failover even before any output — retrying on
// another account would not help. The failed event is forwarded to the client
// and a plain (non-failover) error is returned.
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.failed",
			`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
	}
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	// Not a failover error: the failure is terminal for this request.
	require.False(t, errors.As(err, &failoverErr))
	// The failed event was passed through to the client verbatim.
	require.True(t, c.Writer.Written())
	require.Contains(t, rec.Body.String(), "response.failed")
	require.Contains(t, rec.Body.String(), "high-risk cyber activity")
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
......@@ -1072,7 +1298,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
......@@ -1104,16 +1330,52 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
// TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover:
// the passthrough streaming path mirrors the transform path — a
// response.failed event received before any client-visible output yields an
// UpstreamFailoverError (HTTP 502) carrying the upstream message, with
// nothing written downstream.
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			MaxLineSize: defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.failed",
			`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
	}
	_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
	// No bytes may have reached the client before the failover decision.
	require.False(t, c.Writer.Written())
	require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
......@@ -1139,7 +1401,42 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 2, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.Equal(t, 1, result.usage.CacheReadInputTokens)
}
func TestOpenAIStreamingPassthroughResponseIncompleteWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.incomplete\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
......
package service
import "strings"
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
......@@ -12,8 +14,47 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
if !matched && defaultMappedModel != "" {
if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
return defaultMappedModel
}
return mappedModel
}
// isExplicitCodexModel reports whether the requested model explicitly names a
// known Codex model — optionally namespace-prefixed (e.g. "openai/gpt-5.5")
// or carrying the "-openai-compact" suffix. Such explicit requests keep their
// model and must not fall back to a group-level default mapping.
func isExplicitCodexModel(model string) bool {
	trimmed := strings.TrimSpace(model)
	if trimmed == "" {
		return false
	}
	// Drop any namespace prefix such as "openai/".
	if idx := strings.LastIndex(trimmed, "/"); idx >= 0 {
		trimmed = trimmed[idx+1:]
	}
	normalized := strings.ToLower(strings.TrimSpace(trimmed))
	if getNormalizedCodexModel(normalized) != "" {
		return true
	}
	const compactSuffix = "-openai-compact"
	if !strings.HasSuffix(normalized, compactSuffix) {
		return false
	}
	// Compact variants are explicit when their base model is a Codex model.
	return getNormalizedCodexModel(strings.TrimSuffix(normalized, compactSuffix)) != ""
}
// resolveOpenAICompactForwardModel determines the compact-only upstream model
// for /responses/compact requests. It never affects normal /responses traffic.
// When no compact-specific mapping matches, the input model is returned as-is.
func resolveOpenAICompactForwardModel(account *Account, model string) string {
	requested := strings.TrimSpace(model)
	if account == nil || requested == "" {
		return requested
	}
	mapped, matched := account.ResolveCompactMappedModel(requested)
	if !matched {
		// No compact mapping for this model; keep the request untouched.
		return requested
	}
	if target := strings.TrimSpace(mapped); target != "" {
		return target
	}
	// A blank mapping target is treated as "no override".
	return requested
}
......@@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
requestedModel: "claude-opus-4-6",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
{
name: "preserves explicit gpt-5.4 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
......@@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves codex spark instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.3-codex-spark",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.3-codex-spark",
},
{
name: "preserves gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5",
},
{
name: "preserves openai namespaced gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "openai/gpt-5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "openai/gpt-5.5",
},
{
name: "preserves compact gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.5-openai-compact",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5-openai-compact",
},
}
for _, tt := range tests {
......@@ -85,6 +130,74 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *
}
}
// TestResolveOpenAICompactForwardModel covers compact-only model resolution:
// nil accounts and unmatched mappings keep the original model, while exact,
// wildcard, and passthrough compact mappings behave as configured.
func TestResolveOpenAICompactForwardModel(t *testing.T) {
	cases := []struct {
		name          string
		account       *Account
		model         string
		expectedModel string
	}{
		{
			name:          "nil account keeps original model",
			account:       nil,
			model:         "gpt-5.4",
			expectedModel: "gpt-5.4",
		},
		{
			name: "missing compact mapping keeps original model",
			account: &Account{
				Credentials: map[string]any{},
			},
			model:         "gpt-5.4",
			expectedModel: "gpt-5.4",
		},
		{
			name: "exact compact mapping overrides model",
			account: &Account{
				Credentials: map[string]any{
					"compact_model_mapping": map[string]any{
						"gpt-5.4": "gpt-5.4-openai-compact",
					},
				},
			},
			model:         "gpt-5.4",
			expectedModel: "gpt-5.4-openai-compact",
		},
		{
			name: "wildcard compact mapping overrides model",
			account: &Account{
				Credentials: map[string]any{
					"compact_model_mapping": map[string]any{
						"gpt-5.*": "gpt-5-openai-compact",
					},
				},
			},
			model:         "gpt-5.4",
			expectedModel: "gpt-5-openai-compact",
		},
		{
			name: "passthrough compact mapping remains unchanged",
			account: &Account{
				Credentials: map[string]any{
					"compact_model_mapping": map[string]any{
						"gpt-5.4": "gpt-5.4",
					},
				},
			},
			model:         "gpt-5.4",
			expectedModel: "gpt-5.4",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := resolveOpenAICompactForwardModel(tc.account, tc.model)
			if got != tc.expectedModel {
				t.Fatalf("resolveOpenAICompactForwardModel(...) = %q, want %q", got, tc.expectedModel)
			}
		})
	}
}
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
......
......@@ -734,7 +734,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
require.NoError(t, err)
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "codex_cli_rs/0.125.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
......
......@@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct {
}
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
// 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、
// 或显式声明 tools/tool_choice。
func NeedsToolContinuation(reqBody map[string]any) bool {
if reqBody == nil {
......@@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType == "function_call_output" || itemType == "item_reference" {
if isCodexToolCallItemType(itemType) || itemType == "item_reference" {
return true
}
}
......
......@@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
{name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true},
{name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true},
{name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true},
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
......
......@@ -37,7 +37,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
......@@ -77,7 +77,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
......@@ -129,7 +129,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheck
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
......@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *test
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}, false)
require.NoError(t, err)
require.Nil(t, selection)
}
......@@ -197,7 +197,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
}
......@@ -258,7 +258,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
......
......@@ -3800,6 +3800,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
previousResponseID string,
requestedModel string,
excludedIDs map[int64]struct{},
requireCompact bool,
) (*AccountSelectionResult, error) {
if s == nil {
return nil, nil
......@@ -3840,11 +3841,16 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
// 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。
if requireCompact && openAICompactSupportTier(account) == 0 {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
......@@ -268,6 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
switch action {
case redeemActionSkipCompleted:
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
return err
}
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
......@@ -281,6 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
return err
}
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
......@@ -358,6 +365,142 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
return c > 0
}
// applyAffiliateRebateForOrder accrues the inviter rebate for a completed
// balance recharge order, exactly once per order.
//
// The whole operation runs in a single transaction: a claim row is inserted
// (idempotency guard), the rebate is accrued, and the claim row is finalized
// as APPLIED or SKIPPED. Every failure is mirrored into an
// AFFILIATE_REBATE_FAILED audit entry (outside the tx) and returned wrapped.
// Returns nil without doing anything for non-balance orders, non-positive
// amounts, a missing affiliate service, or an already-claimed order.
func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error {
	// Only positive-amount balance orders qualify for invite rebates.
	if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
		return nil
	}
	if s.affiliateService == nil {
		return nil
	}
	tx, err := s.entClient.Tx(ctx)
	if err != nil {
		return s.failAffiliateRebate(ctx, o.ID,
			fmt.Sprintf("begin affiliate rebate tx: %v", err),
			fmt.Errorf("begin affiliate rebate tx: %w", err))
	}
	// Rollback is a no-op after a successful Commit.
	defer func() { _ = tx.Rollback() }()
	txCtx := dbent.NewTxContext(ctx, tx)
	claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount)
	if err != nil {
		return s.failAffiliateRebate(ctx, o.ID, err.Error(),
			fmt.Errorf("claim affiliate rebate audit: %w", err))
	}
	if !claimed {
		// Another call (or a retry) already processed this order.
		return nil
	}
	rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
	if err != nil {
		return s.failAffiliateRebate(ctx, o.ID, err.Error(),
			fmt.Errorf("accrue affiliate rebate: %w", err))
	}
	// Finalize the claim row. rebateAmount <= 0 means no inviter is bound or
	// the computed rebate is zero; the claim is then marked SKIPPED instead.
	action := "AFFILIATE_REBATE_APPLIED"
	wrapMsg := "update affiliate rebate applied audit"
	detail := map[string]any{
		"baseAmount":   o.Amount,
		"rebateAmount": rebateAmount,
	}
	if rebateAmount <= 0 {
		action = "AFFILIATE_REBATE_SKIPPED"
		wrapMsg = "update affiliate rebate skipped audit"
		detail = map[string]any{
			"baseAmount": o.Amount,
			"reason":     "no inviter bound or rebate amount <= 0",
		}
	}
	if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, action, detail); err != nil {
		return s.failAffiliateRebate(ctx, o.ID, err.Error(), fmt.Errorf("%s: %w", wrapMsg, err))
	}
	if err := tx.Commit(); err != nil {
		return s.failAffiliateRebate(ctx, o.ID,
			fmt.Sprintf("commit affiliate rebate tx: %v", err),
			fmt.Errorf("commit affiliate rebate tx: %w", err))
	}
	return nil
}

// failAffiliateRebate records an AFFILIATE_REBATE_FAILED audit entry with the
// given detail message and returns the wrapped error for the caller to bubble up.
func (s *PaymentService) failAffiliateRebate(ctx context.Context, orderID int64, detail string, wrapped error) error {
	s.writeAuditLog(ctx, orderID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
		"error": detail,
	})
	return wrapped
}
// tryClaimAffiliateRebateAudit atomically reserves the rebate for an order by
// inserting a placeholder AFFILIATE_REBATE_APPLIED audit row.
// It returns (true, nil) only when this call won the claim; (false, nil) when a
// terminal rebate row (APPLIED or SKIPPED) already exists, which makes retries
// idempotent. The reserved row is later finalized by
// updateClaimedAffiliateRebateAudit within the same transaction.
func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
	if client == nil {
		return false, errors.New("nil payment client")
	}
	oid := strconv.FormatInt(orderID, 10)
	// Placeholder detail marking the claim as reserved. Marshalling a plain
	// string/float map cannot fail, so the error is deliberately ignored.
	detail, _ := json.Marshal(map[string]any{
		"baseAmount": baseAmount,
		"status":     "reserved",
	})
	// INSERT ... WHERE NOT EXISTS skips the insert when a terminal row already
	// exists, and ON CONFLICT DO NOTHING absorbs a concurrent claim racing on
	// the (order_id, action) unique key. RETURNING id yields a row only when
	// the insert actually happened. ($N placeholders: PostgreSQL syntax.)
	rows, err := client.QueryContext(ctx, `
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW()
WHERE NOT EXISTS (
    SELECT 1
    FROM payment_audit_logs
    WHERE order_id = $1::text
      AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
)
ON CONFLICT (order_id, action) DO NOTHING
RETURNING id`, oid, string(detail))
	if err != nil {
		return false, err
	}
	defer func() { _ = rows.Close() }()
	if !rows.Next() {
		// No row returned: either the claim already existed (not an error) or
		// iteration failed — distinguish via rows.Err().
		if err := rows.Err(); err != nil {
			return false, err
		}
		return false, nil
	}
	// The returned id is scanned only to consume the row; its value is unused.
	var claimID int64
	if err := rows.Scan(&claimID); err != nil {
		return false, err
	}
	return true, nil
}
// updateClaimedAffiliateRebateAudit promotes the reserved claim row (inserted
// by tryClaimAffiliateRebateAudit with action AFFILIATE_REBATE_APPLIED) to its
// terminal action and detail payload. It is an error for the claim row to be
// missing — that means the claim/finalize pairing was broken.
func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error {
	if client == nil {
		return errors.New("nil payment client")
	}
	// A plain map of JSON-safe values cannot fail to marshal.
	payload, _ := json.Marshal(detail)
	affected, err := client.PaymentAuditLog.Update().
		Where(
			paymentauditlog.OrderIDEQ(strconv.FormatInt(orderID, 10)),
			paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"),
		).
		SetAction(action).
		SetDetail(string(payload)).
		SetOperator("system").
		Save(ctx)
	switch {
	case err != nil:
		return err
	case affected == 0:
		return errors.New("affiliate rebate claim log not found")
	default:
		return nil
	}
}
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)
......
......@@ -181,10 +181,11 @@ type PaymentService struct {
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
affiliateService *AffiliateService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService}
svc.resumeService = psNewPaymentResumeService(configService)
return svc
}
......
......@@ -931,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
func calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
......@@ -977,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
// calculateOpenAI429ResetTime is a thin method shim kept so existing callers on
// *RateLimitService keep working; the logic lives in the package-level
// calculateOpenAI429ResetTime, presumably extracted so it can be reused without
// a receiver — confirm against new call sites.
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
	return calculateOpenAI429ResetTime(headers)
}
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited
......
......@@ -7,6 +7,7 @@ import (
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -37,6 +38,7 @@ type RequestCaptureService struct {
repo RequestCaptureLogRepository
nfsPath string
timeout time.Duration
nfsPathMap sync.Map // captureID int64 → nfsFilePath string(短暂存活,CaptureResponse 调用后删除)
}
// nfsFileEnvelope 是写入 NFS 文件的 JSON 结构。
......@@ -109,15 +111,24 @@ func (s *RequestCaptureService) Capture(
)
return 0
}
// 记录 captureID → nfsFilePath 映射,供 CaptureResponse 写响应文件用
if nfsFilePath != "" {
s.nfsPathMap.Store(id, nfsFilePath)
}
return id
}
// CaptureResponse 异步将响应体写入已有的捕获记录,不阻塞调用方。
// CaptureResponse 异步将响应体写入已有的捕获记录(数据库 + NFS),不阻塞调用方。
// captureID 为 Capture 返回的 ID,为 0 时直接忽略。
func (s *RequestCaptureService) CaptureResponse(captureID int64, responseBody string) {
if captureID == 0 || responseBody == "" {
return
}
// 取出并删除 NFS 路径映射(一次性消费)
var nfsFilePath string
if v, ok := s.nfsPathMap.LoadAndDelete(captureID); ok {
nfsFilePath, _ = v.(string)
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), s.timeout)
defer cancel()
......@@ -127,9 +138,22 @@ func (s *RequestCaptureService) CaptureResponse(captureID int64, responseBody st
zap.Error(err),
)
}
// NFS 响应文件:与请求文件同目录,文件名加 _response 后缀
if nfsFilePath != "" {
respPath := nfsResponseFilePath(nfsFilePath)
s.writeResponseToNFS(respPath, captureID, responseBody)
}
}()
}
// nfsResponseFilePath derives the response-file path from a request-file path
// by inserting a "_response" suffix in front of the extension.
// Example: /nfs/2024-01-01/42/123_reqid.json → /nfs/2024-01-01/42/123_reqid_response.json
// A path without an extension simply gets "_response" appended.
func nfsResponseFilePath(requestPath string) string {
	ext := filepath.Ext(requestPath)
	stemLen := len(requestPath) - len(ext)
	return requestPath[:stemLen] + "_response" + ext
}
func (s *RequestCaptureService) buildNFSFilePath(apiKeyID int64, requestID string, t time.Time) string {
date := t.UTC().Format("2006-01-02")
filename := fmt.Sprintf("%d_%s.json", t.UnixNano(), requestID)
......@@ -180,3 +204,44 @@ func (s *RequestCaptureService) writeToNFS(
)
}
}
// nfsResponseEnvelope is the JSON structure written to the NFS response file.
type nfsResponseEnvelope struct {
	CaptureID int64           `json:"capture_id"` // ID returned by Capture; links the response file to the DB record
	CreatedAt time.Time       `json:"created_at"` // UTC timestamp of when the envelope was written
	Body      json.RawMessage `json:"body"`       // response body embedded verbatim, without re-encoding
}
// writeResponseToNFS persists a captured response body as a JSON envelope at
// filePath (next to the request file). All failures are logged and swallowed:
// NFS persistence is best-effort and must never affect request handling.
//
// Non-JSON bodies (e.g. raw SSE text from streaming requests) are stored as a
// JSON string; previously they made Encode fail on the invalid json.RawMessage
// and the response file was silently dropped.
func (s *RequestCaptureService) writeResponseToNFS(filePath string, captureID int64, responseBody string) {
	dir := filepath.Dir(filePath)
	if err := os.MkdirAll(dir, 0o755); err != nil {
		logger.L().Error("request_capture: mkdir failed (response)",
			zap.String("dir", dir),
			zap.Error(err),
		)
		return
	}
	// Embed valid JSON verbatim; fall back to a quoted JSON string otherwise.
	body := json.RawMessage(responseBody)
	if !json.Valid(body) {
		quoted, err := json.Marshal(responseBody)
		if err != nil {
			logger.L().Error("request_capture: json marshal failed (response)",
				zap.Int64("capture_id", captureID),
				zap.Error(err),
			)
			return
		}
		body = json.RawMessage(quoted)
	}
	envelope := nfsResponseEnvelope{
		CaptureID: captureID,
		CreatedAt: time.Now().UTC(),
		Body:      body,
	}
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	// Keep <, >, & literal in captured payloads.
	enc.SetEscapeHTML(false)
	if err := enc.Encode(envelope); err != nil {
		logger.L().Error("request_capture: json marshal failed (response)",
			zap.Int64("capture_id", captureID),
			zap.Error(err),
		)
		return
	}
	if err := os.WriteFile(filePath, buf.Bytes(), 0o644); err != nil {
		logger.L().Error("request_capture: nfs write failed (response)",
			zap.String("file", filePath),
			zap.Error(err),
		)
	}
}
......@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"log/slog"
"math"
"net/url"
"sort"
"strconv"
......@@ -453,6 +454,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyChannelMonitorEnabled,
SettingKeyChannelMonitorDefaultIntervalSeconds,
SettingKeyAvailableChannelsEnabled,
SettingKeyAffiliateEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -540,6 +542,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]),
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
}, nil
}
......@@ -686,6 +690,7 @@ type PublicSettingsInjectionPayload struct {
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
}
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
......@@ -738,6 +743,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
}, nil
}
......@@ -1167,6 +1173,26 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
if settings.AffiliateRebateFreezeHours < 0 {
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault
}
if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax {
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax
}
updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours)
if settings.AffiliateRebateDurationDays < 0 {
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault
}
if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax {
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax
}
updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays)
if settings.AffiliateRebatePerInviteeCap < 0 {
settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
}
updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
......@@ -1202,6 +1228,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// Available channels feature switch
updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled)
// Affiliate (邀请返利) feature switch
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
......@@ -1477,6 +1506,78 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
return value == "true"
}
// IsAffiliateEnabled reports whether the invite-rebate feature switch is on.
// Any repository error is treated as "disabled" — the safe default.
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
	return err == nil && raw == "true"
}
// GetAffiliateRebateRatePercent reads the global rebate rate and clamps it to
// the allowed range. Missing, unparsable, or non-finite values all fall back
// to AffiliateRebateRateDefault — this getter never returns an error, because
// callers only ever need a usable number.
func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) float64 {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
	if err != nil {
		return AffiliateRebateRateDefault
	}
	value, parseErr := strconv.ParseFloat(strings.TrimSpace(raw), 64)
	if parseErr != nil || math.IsInf(value, 0) || math.IsNaN(value) {
		return AffiliateRebateRateDefault
	}
	return clampAffiliateRebateRate(value)
}
// GetAffiliateRebateFreezeHours returns the rebate freeze window in hours.
// A value of 0 means "no freeze" (backward compatible). Missing, malformed,
// or negative settings fall back to the default; values above the maximum
// are capped.
func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours)
	if err != nil {
		return AffiliateRebateFreezeHoursDefault
	}
	parsed, convErr := strconv.Atoi(strings.TrimSpace(raw))
	switch {
	case convErr != nil || parsed < 0:
		return AffiliateRebateFreezeHoursDefault
	case parsed > AffiliateRebateFreezeHoursMax:
		return AffiliateRebateFreezeHoursMax
	default:
		return parsed
	}
}
// GetAffiliateRebateDurationDays returns the rebate validity window in days.
// A value of 0 means the rebate never expires. Missing, malformed, or negative
// settings fall back to the default; values above the maximum are capped.
func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int {
	days := AffiliateRebateDurationDaysDefault
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays)
	if err != nil {
		return days
	}
	if parsed, convErr := strconv.Atoi(strings.TrimSpace(raw)); convErr == nil && parsed >= 0 {
		days = parsed
	}
	if days > AffiliateRebateDurationDaysMax {
		days = AffiliateRebateDurationDaysMax
	}
	return days
}
// GetAffiliateRebatePerInviteeCap returns the per-invitee rebate cap.
// A value of 0 means "no cap". Missing, malformed, negative, or non-finite
// settings fall back to AffiliateRebatePerInviteeCapDefault.
func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap)
	if err != nil {
		return AffiliateRebatePerInviteeCapDefault
	}
	// Named capValue (not `cap`) to avoid shadowing the predeclared cap()
	// builtin, which linters flag and which would be unusable in this scope.
	capValue, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
	if err != nil || capValue < 0 || math.IsNaN(capValue) || math.IsInf(capValue, 0) {
		return AffiliateRebatePerInviteeCapDefault
	}
	return capValue
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
......@@ -1719,6 +1820,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
......@@ -1767,6 +1872,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Available channels feature (default disabled; opt-in)
SettingKeyAvailableChannelsEnabled: "false",
// Affiliate (邀请返利) feature (default disabled; opt-in)
SettingKeyAffiliateEnabled: "false",
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
......@@ -1846,6 +1954,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil {
result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate)
} else {
result.AffiliateRebateRate = AffiliateRebateRateDefault
}
if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 {
if freezeHours > AffiliateRebateFreezeHoursMax {
freezeHours = AffiliateRebateFreezeHoursMax
}
result.AffiliateRebateFreezeHours = freezeHours
}
if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 {
if durationDays > AffiliateRebateDurationDaysMax {
durationDays = AffiliateRebateDurationDaysMax
}
result.AffiliateRebateDurationDays = durationDays
}
if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
result.AffiliateRebatePerInviteeCap = perInviteeCap
}
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
......@@ -2082,6 +2210,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Available channels feature (default: disabled; strict true)
result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true"
// Affiliate (邀请返利) feature (default: disabled; strict true)
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
......@@ -2130,6 +2261,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
return result
}
// clampAffiliateRebateRate constrains a rebate rate to
// [AffiliateRebateRateMin, AffiliateRebateRateMax].
// Non-finite inputs (NaN, ±Inf) map to AffiliateRebateRateDefault.
func clampAffiliateRebateRate(value float64) float64 {
	switch {
	case math.IsNaN(value), math.IsInf(value, 0):
		return AffiliateRebateRateDefault
	case value < AffiliateRebateRateMin:
		return AffiliateRebateRateMin
	case value > AffiliateRebateRateMax:
		return AffiliateRebateRateMax
	default:
		return value
	}
}
func isFalseSettingValue(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "false", "0", "off", "disabled":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment