Unverified Commit fa68cbad authored by InCerryGit's avatar InCerryGit Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 995ef134 0f033930
...@@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct { ...@@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct {
Model string Model string
RequestPath string RequestPath string
Stream bool Stream bool
// InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint string
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint string
// RequestedModel is the client-requested model name before mapping.
RequestedModel string
// UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping.
UpstreamModel string
// RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2.
// Matches service.RequestType enum semantics from usage_log.go.
RequestType *int16
UserAgent string UserAgent string
ErrorPhase string ErrorPhase string
......
...@@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct { ...@@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct {
UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"`
// UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped).
// Helps debug 404/routing errors by showing which endpoint was targeted.
UpstreamURL string `json:"upstream_url,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed). // Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt. // Required for retrying a specific upstream attempt.
UpstreamRequestBody string `json:"upstream_request_body,omitempty"` UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
...@@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { ...@@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody) ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody) ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind) ev.Kind = strings.TrimSpace(ev.Kind)
ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL)
ev.Message = strings.TrimSpace(ev.Message) ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail) ev.Detail = strings.TrimSpace(ev.Detail)
if ev.Message != "" { if ev.Message != "" {
...@@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) { ...@@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
} }
return out, nil return out, nil
} }
// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment
// to avoid leaking sensitive query parameters (e.g. OAuth tokens).
func safeUpstreamURL(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return ""
}
if idx := strings.IndexByte(rawURL, '?'); idx >= 0 {
rawURL = rawURL[:idx]
}
if idx := strings.IndexByte(rawURL, '#'); idx >= 0 {
rawURL = rawURL[:idx]
}
return rawURL
}
...@@ -8,6 +8,27 @@ import ( ...@@ -8,6 +8,27 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestSafeUpstreamURL(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"},
{"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"},
{"strips both", "https://host/path?token=secret#x", "https://host/path"},
{"no query or fragment", "https://host/path", "https://host/path"},
{"empty string", "", ""},
{"whitespace only", " ", ""},
{"query before fragment", "https://h/p?a=1#f", "https://h/p"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, safeUpstreamURL(tt.input))
})
}
}
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
......
...@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil { if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else { } else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
......
...@@ -17,6 +17,8 @@ type rateLimitAccountRepoStub struct { ...@@ -17,6 +17,8 @@ type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini mockAccountRepoForGemini
setErrorCalls int setErrorCalls int
tempCalls int tempCalls int
updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string lastErrorMsg string
} }
...@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id ...@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return nil return nil
} }
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCredentialsCalls++
r.lastCredentials = cloneCredentials(credentials)
return nil
}
type tokenCacheInvalidatorRecorder struct { type tokenCacheInvalidatorRecorder struct {
accounts []*Account accounts []*Account
err error err error
...@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin ...@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require.True(t, shouldDisable) require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls) require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
} }
...@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { ...@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts) require.Empty(t, invalidator.accounts)
} }
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}
...@@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic( ...@@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic(
func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]Account, *pagination.PaginationResult, error) { func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) { func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment