Commit 6901b64f authored by cyhhao's avatar cyhhao
Browse files

merge: sync upstream changes

parents 32c47b15 dae0d532
......@@ -545,12 +545,19 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
......@@ -588,6 +595,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "signature_error",
......@@ -662,6 +670,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
......@@ -711,6 +720,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
......@@ -737,6 +747,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
......@@ -972,12 +983,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
......@@ -1036,6 +1054,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
......@@ -1120,6 +1139,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
......@@ -1143,6 +1163,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
......@@ -1168,6 +1189,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
......@@ -1300,6 +1322,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",
......
......@@ -125,6 +125,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
......@@ -138,6 +141,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearModelRateLimits(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
......
......@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
DeleteAccessToken(ctx context.Context, cacheKey string) error
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
......
......@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("not a gemini oauth account")
}
cacheKey := geminiTokenCacheKey(account)
cacheKey := GeminiTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
......@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func geminiTokenCacheKey(account *Account) string {
func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return projectID
return "gemini:" + projectID
}
return "account:" + strconv.FormatInt(account.ID, 10)
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
}
package service
import "time"
import (
"strings"
"time"
)
type Group struct {
ID int64
......@@ -27,6 +30,12 @@ type Group struct {
ClaudeCodeOnly bool
FallbackGroupID *int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
// value: 优先账号 ID 列表
ModelRouting map[string][]int64
ModelRoutingEnabled bool
CreatedAt time.Time
UpdatedAt time.Time
......@@ -90,3 +99,41 @@ func IsGroupContextValid(group *Group) bool {
}
return true
}
// GetRoutingAccountIDs 根据请求模型获取路由账号 ID 列表
// 返回匹配的优先账号 ID 列表,如果没有匹配规则则返回 nil
func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" {
return nil
}
// 1. 精确匹配优先
if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 {
return accountIDs
}
// 2. 通配符匹配(前缀匹配)
for pattern, accountIDs := range g.ModelRouting {
if matchModelPattern(pattern, requestedModel) && len(accountIDs) > 0 {
return accountIDs
}
}
return nil
}
// matchModelPattern 检查模型是否匹配模式
// 支持 * 通配符,如 "claude-opus-*" 匹配 "claude-opus-4-20250514"
func matchModelPattern(pattern, model string) bool {
if pattern == model {
return true
}
// 处理 * 通配符(仅支持末尾通配符)
if strings.HasSuffix(pattern, "*") {
prefix := strings.TrimSuffix(pattern, "*")
return strings.HasPrefix(model, prefix)
}
return false
}
package service
import (
"strings"
"time"
)
const modelRateLimitsKey = "model_rate_limits"
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
model := strings.ToLower(strings.TrimSpace(requestedModel))
if model == "" {
return "", false
}
model = strings.TrimPrefix(model, "models/")
if strings.Contains(model, "sonnet") {
return modelRateLimitScopeClaudeSonnet, true
}
return "", false
}
func (a *Account) isModelRateLimited(requestedModel string) bool {
scope, ok := resolveModelRateLimitScope(requestedModel)
if !ok {
return false
}
resetAt := a.modelRateLimitResetAt(scope)
if resetAt == nil {
return false
}
return time.Now().Before(*resetAt)
}
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawLimits, ok := a.Extra[modelRateLimitsKey].(map[string]any)
if !ok {
return nil
}
rawLimit, ok := rawLimits[scope].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawLimit["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}
......@@ -93,6 +93,8 @@ type OpenAIGatewayService struct {
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
......@@ -110,6 +112,7 @@ func NewOpenAIGatewayService(
billingCacheService *BillingCacheService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
return &OpenAIGatewayService{
accountRepo: accountRepo,
......@@ -125,6 +128,8 @@ func NewOpenAIGatewayService(
billingCacheService: billingCacheService,
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
}
}
......@@ -503,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth:
// 使用 TokenProvider 获取缓存的 token
if s.openAITokenProvider != nil {
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 降级:TokenProvider 未配置时直接从账号读取
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
......@@ -664,6 +678,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL = account.Proxy.URL()
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
// Send request
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
......@@ -673,6 +692,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
......@@ -707,6 +727,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
......@@ -864,6 +885,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
......@@ -894,6 +916,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
......@@ -1097,6 +1120,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
data = correctedData
line = "data: " + correctedData
}
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
......@@ -1199,6 +1228,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line
}
// correctToolCallsInResponseBody 修正响应体中的工具调用
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
if len(body) == 0 {
return body
}
bodyStr := string(body)
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
if changed {
return []byte(corrected)
}
return body
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
......@@ -1302,6 +1345,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body)
} else {
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel {
......@@ -1470,28 +1515,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 添加 UserAgent
......
package service
import (
"strings"
"testing"
)
// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
// 创建一个简单的 service 实例来测试工具修正
service := &OpenAIGatewayService{
toolCorrector: NewCodexToolCorrector(),
}
tests := []struct {
name string
input []byte
expected string
changed bool
}{
{
name: "correct apply_patch in response body",
input: []byte(`{
"choices": [{
"message": {
"tool_calls": [{
"function": {"name": "apply_patch"}
}]
}
}]
}`),
expected: "edit",
changed: true,
},
{
name: "correct update_plan in response body",
input: []byte(`{
"tool_calls": [{
"function": {"name": "update_plan"}
}]
}`),
expected: "todowrite",
changed: true,
},
{
name: "no change for correct tool name",
input: []byte(`{
"tool_calls": [{
"function": {"name": "edit"}
}]
}`),
expected: "edit",
changed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := service.correctToolCallsInResponseBody(tt.input)
resultStr := string(result)
// 检查是否包含期望的工具名称
if !strings.Contains(resultStr, tt.expected) {
t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
}
// 对于预期有变化的情况,验证结果与输入不同
if tt.changed && string(result) == string(tt.input) {
t.Error("expected result to be different from input, but they are the same")
}
// 对于预期无变化的情况,验证结果与输入相同
if !tt.changed && string(result) != string(tt.input) {
t.Error("expected result to be same as input, but they are different")
}
})
}
}
// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
service := &OpenAIGatewayService{
toolCorrector: NewCodexToolCorrector(),
}
if service.toolCorrector == nil {
t.Fatal("toolCorrector should not be nil")
}
// 测试修正器可以正常工作
data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
if !changed {
t.Error("expected tool call to be corrected")
}
if !strings.Contains(corrected, "edit") {
t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
}
}
// TestToolCorrectionStats 测试工具修正统计功能
func TestToolCorrectionStats(t *testing.T) {
service := &OpenAIGatewayService{
toolCorrector: NewCodexToolCorrector(),
}
// 执行几次修正
testData := []string{
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
}
for _, data := range testData {
service.toolCorrector.CorrectToolCallsInSSEData(data)
}
stats := service.toolCorrector.GetStats()
if stats.TotalCorrected != 3 {
t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
}
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
}
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
}
}
package service
import (
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
openAITokenRefreshSkew = 3 * time.Minute
openAITokenCacheSkew = 5 * time.Minute
openAILockWaitTime = 200 * time.Millisecond
)
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type OpenAITokenCache = GeminiTokenCache
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService
}
func NewOpenAITokenProvider(
accountRepo AccountRepository,
tokenCache OpenAITokenCache,
openAIOAuthService *OpenAIOAuthService,
) *OpenAITokenProvider {
return &OpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
openAIOAuthService: openAIOAuthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消
if ctx.Err() != nil {
return "", ctx.Err()
}
// 从数据库获取最新账户信息
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
refreshFailed = true
} else {
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time.Sleep(openAILockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetOpenAIAccessToken()
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > openAITokenCacheSkew:
ttl = until - openAITokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
return accessToken, nil
}
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// openAITokenCacheStub implements OpenAITokenCache for testing
type openAITokenCacheStub struct {
mu sync.Mutex
tokens map[string]string
getErr error
setErr error
deleteErr error
lockAcquired bool
lockErr error
releaseLockErr error
getCalled int32
setCalled int32
lockCalled int32
unlockCalled int32
simulateLockRace bool
}
func newOpenAITokenCacheStub() *openAITokenCacheStub {
return &openAITokenCacheStub{
tokens: make(map[string]string),
lockAcquired: true,
}
}
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
atomic.AddInt32(&s.getCalled, 1)
if s.getErr != nil {
return "", s.getErr
}
s.mu.Lock()
defer s.mu.Unlock()
return s.tokens[cacheKey], nil
}
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
atomic.AddInt32(&s.setCalled, 1)
if s.setErr != nil {
return s.setErr
}
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[cacheKey] = token
return nil
}
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
if s.deleteErr != nil {
return s.deleteErr
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, cacheKey)
return nil
}
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
atomic.AddInt32(&s.lockCalled, 1)
if s.lockErr != nil {
return false, s.lockErr
}
if s.simulateLockRace {
return false, nil
}
return s.lockAcquired, nil
}
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
atomic.AddInt32(&s.unlockCalled, 1)
return s.releaseLockErr
}
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
type openAIAccountRepoStub struct {
account *Account
getErr error
updateErr error
getCalled int32
updateCalled int32
}
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
atomic.AddInt32(&r.getCalled, 1)
if r.getErr != nil {
return nil, r.getErr
}
return r.account, nil
}
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
atomic.AddInt32(&r.updateCalled, 1)
if r.updateErr != nil {
return r.updateErr
}
r.account = account
return nil
}
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
type openAIOAuthServiceStub struct {
tokenInfo *OpenAITokenInfo
refreshErr error
refreshCalled int32
}
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
atomic.AddInt32(&s.refreshCalled, 1)
if s.refreshErr != nil {
return nil, s.refreshErr
}
return s.tokenInfo, nil
}
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
now := time.Now()
return map[string]any{
"access_token": info.AccessToken,
"refresh_token": info.RefreshToken,
"expires_at": now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
}
}
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
cache := newOpenAITokenCacheStub()
account := &Account{
ID: 100,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "db-token",
},
}
cacheKey := OpenAITokenCacheKey(account)
cache.tokens[cacheKey] = "cached-token"
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "cached-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
cache := newOpenAITokenCacheStub()
// Token expires in far future, no refresh needed
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "credential-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "credential-token", token)
// Should have stored in cache
cacheKey := OpenAITokenCacheKey(account)
require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
oauthService := &openAIOAuthServiceStub{
tokenInfo: &OpenAITokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh-token",
ExpiresIn: 3600,
},
}
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// We need to directly test with the stub - create a custom provider
customProvider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := customProvider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
// testOpenAITokenProvider is a test version that uses the stub OAuth service
type testOpenAITokenProvider struct {
accountRepo *openAIAccountRepoStub
tokenCache *openAITokenCacheStub
oauthService *openAIOAuthServiceStub
}
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
// 1. Check cache
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
// 2. Check if refresh needed
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// Check cache again after acquiring lock
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
// Get fresh account from DB
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.oauthService == nil {
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if p.tokenCache.simulateLockRace {
// Wait and retry cache
time.Sleep(10 * time.Millisecond) // Short wait for test
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. Store in cache
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
ttl = time.Minute // 刷新失败时使用短 TTL
} else if expiresAt != nil {
until := time.Until(*expiresAt)
if until > openAITokenCacheSkew {
ttl = until - openAITokenCacheSkew
} else if until > 0 {
ttl = until
} else {
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.simulateLockRace = true
accountRepo := &openAIAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "race-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// Simulate another worker already refreshed and cached
cacheKey := OpenAITokenCacheKey(account)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Should get the token set by the "winner" or the original
require.NotEmpty(t, token)
}
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
}
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
account := &Account{
ID: 104,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
account := &Account{
ID: 105,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
func TestOpenAITokenProvider_NilCache(t *testing.T) {
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 106,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "nocache-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "nocache-token", token)
}
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.getErr = errors.New("redis connection failed")
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 107,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
// Should gracefully degrade and return from credentials
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-token", token)
}
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.setErr = errors.New("redis write failed")
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 108,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "still-works-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
// Should still work even if cache set fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "still-works-token", token)
}
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 109,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// missing access_token
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
oauthService := &openAIOAuthServiceStub{
refreshErr: errors.New("oauth refresh failed"),
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 110,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Now with fallback behavior, should return existing token even if refresh fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 111,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: nil, // not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
tests := []struct {
name string
expiresIn time.Duration
}{
{
name: "far_future_expiry",
expiresIn: 1 * time.Hour,
},
{
name: "medium_expiry",
expiresIn: 10 * time.Minute,
},
{
name: "near_expiry",
expiresIn: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "test-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
_, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Verify token was cached
cacheKey := OpenAITokenCacheKey(account)
require.Equal(t, "test-token", cache.tokens[cacheKey])
})
}
}
func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
oauthService := &openAIOAuthServiceStub{
tokenInfo: &OpenAITokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 112,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
cacheKey := OpenAITokenCacheKey(account)
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
originalGet := int32(0)
cache.tokens[cacheKey] = "" // Empty initially
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// In a goroutine, set the cached token after a small delay (simulating race)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "cached-by-other"
cache.mu.Unlock()
}()
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Should get either the refreshed token or the cached one
require.NotEmpty(t, token)
_ = originalGet // Suppress unused warning
}
// Tests for real provider - to increase coverage
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey := OpenAITokenCacheKey(account)
go func() {
time.Sleep(100 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "refreshed-by-other"
cache.mu.Unlock()
}()
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Should get either the fallback token or the refreshed one
require.NotEmpty(t, token)
}
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 201,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "original-token",
"expires_at": expiresAt,
},
}
cacheKey := OpenAITokenCacheKey(account)
// Set token in cache immediately after wait starts
go func() {
time.Sleep(50 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // Prevent entering refresh logic
// Token with nil expires_at (no expiry set) - should use credentials
account := &Account{
ID: 202,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "no-expiry-token",
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
// Without OAuth service, refresh will fail but token should be returned from credentials
require.NoError(t, err)
require.Equal(t, "no-expiry-token", token)
}
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
cacheKey := "openai:account:203"
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 203,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "real-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "real-token", token) // Should fall back to credentials
}
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockErr = errors.New("redis lock failed")
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 204,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-on-lock-error",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-on-lock-error", token)
}
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 205,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": " ", // Whitespace only
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 206,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// No access_token
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
package service
import (
"encoding/json"
"fmt"
"log"
"sync"
)
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
var codexToolNameMapping = map[string]string{
"apply_patch": "edit",
"applyPatch": "edit",
"update_plan": "todowrite",
"updatePlan": "todowrite",
"read_plan": "todoread",
"readPlan": "todoread",
"search_files": "grep",
"searchFiles": "grep",
"list_files": "glob",
"listFiles": "glob",
"read_file": "read",
"readFile": "read",
"write_file": "write",
"writeFile": "write",
"execute_bash": "bash",
"executeBash": "bash",
"exec_bash": "bash",
"execBash": "bash",
}
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
type ToolCorrectionStats struct {
TotalCorrected int `json:"total_corrected"`
CorrectionsByTool map[string]int `json:"corrections_by_tool"`
}
// CodexToolCorrector 处理 Codex 工具调用的自动修正
type CodexToolCorrector struct {
stats ToolCorrectionStats
mu sync.RWMutex
}
// NewCodexToolCorrector 创建新的工具修正器
func NewCodexToolCorrector() *CodexToolCorrector {
return &CodexToolCorrector{
stats: ToolCorrectionStats{
CorrectionsByTool: make(map[string]int),
},
}
}
// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用
// 返回修正后的数据和是否进行了修正
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
if data == "" || data == "\n" {
return data, false
}
// 尝试解析 JSON
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
// 不是有效的 JSON,直接返回原数据
return data, false
}
corrected := false
// 处理 tool_calls 数组
if toolCalls, ok := payload["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
// 处理 function_call 对象
if functionCall, ok := payload["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
// 处理 delta.tool_calls
if delta, ok := payload["delta"].(map[string]any); ok {
if toolCalls, ok := delta["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
if functionCall, ok := delta["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
}
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
if choices, ok := payload["choices"].([]any); ok {
for _, choice := range choices {
if choiceMap, ok := choice.(map[string]any); ok {
// 处理 message 中的工具调用
if message, ok := choiceMap["message"].(map[string]any); ok {
if toolCalls, ok := message["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
if functionCall, ok := message["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
}
// 处理 delta 中的工具调用
if delta, ok := choiceMap["delta"].(map[string]any); ok {
if toolCalls, ok := delta["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
if functionCall, ok := delta["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
}
}
}
}
if !corrected {
return data, false
}
// 序列化回 JSON
correctedBytes, err := json.Marshal(payload)
if err != nil {
log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
return data, false
}
return string(correctedBytes), true
}
// correctToolCallsArray 修正工具调用数组中的工具名称
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
corrected := false
for _, toolCall := range toolCalls {
if toolCallMap, ok := toolCall.(map[string]any); ok {
if function, ok := toolCallMap["function"].(map[string]any); ok {
if c.correctFunctionCall(function) {
corrected = true
}
}
}
}
return corrected
}
// correctFunctionCall 修正单个函数调用的工具名称和参数
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
name, ok := functionCall["name"].(string)
if !ok || name == "" {
return false
}
corrected := false
// 查找并修正工具名称
if correctName, found := codexToolNameMapping[name]; found {
functionCall["name"] = correctName
c.recordCorrection(name, correctName)
corrected = true
name = correctName // 使用修正后的名称进行参数修正
}
// 修正工具参数(基于工具名称)
if c.correctToolParameters(name, functionCall) {
corrected = true
}
return corrected
}
// correctToolParameters 修正工具参数以符合 OpenCode 规范
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
arguments, ok := functionCall["arguments"]
if !ok {
return false
}
// arguments 可能是字符串(JSON)或已解析的 map
var argsMap map[string]any
switch v := arguments.(type) {
case string:
// 解析 JSON 字符串
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
return false
}
case map[string]any:
argsMap = v
default:
return false
}
corrected := false
// 根据工具名称应用特定的参数修正规则
switch toolName {
case "bash":
// 移除 workdir 参数(OpenCode 不支持)
if _, exists := argsMap["workdir"]; exists {
delete(argsMap, "workdir")
corrected = true
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
}
if _, exists := argsMap["work_dir"]; exists {
delete(argsMap, "work_dir")
corrected = true
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
}
case "edit":
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
// 这里可以添加参数名称的映射逻辑
if _, exists := argsMap["file_path"]; !exists {
if path, exists := argsMap["path"]; exists {
argsMap["file_path"] = path
delete(argsMap, "path")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
}
}
}
// 如果修正了参数,需要重新序列化
if corrected {
if _, wasString := arguments.(string); wasString {
// 原本是字符串,序列化回字符串
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
functionCall["arguments"] = string(newArgsJSON)
}
} else {
// 原本是 map,直接赋值
functionCall["arguments"] = argsMap
}
}
return corrected
}
// recordCorrection 记录一次工具名称修正
func (c *CodexToolCorrector) recordCorrection(from, to string) {
c.mu.Lock()
defer c.mu.Unlock()
c.stats.TotalCorrected++
key := fmt.Sprintf("%s->%s", from, to)
c.stats.CorrectionsByTool[key]++
log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
from, to, c.stats.TotalCorrected)
}
// GetStats 获取工具修正统计信息
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
c.mu.RLock()
defer c.mu.RUnlock()
// 返回副本以避免并发问题
statsCopy := ToolCorrectionStats{
TotalCorrected: c.stats.TotalCorrected,
CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)),
}
for k, v := range c.stats.CorrectionsByTool {
statsCopy.CorrectionsByTool[k] = v
}
return statsCopy
}
// ResetStats 重置统计信息
func (c *CodexToolCorrector) ResetStats() {
c.mu.Lock()
defer c.mu.Unlock()
c.stats.TotalCorrected = 0
c.stats.CorrectionsByTool = make(map[string]int)
}
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
func CorrectToolName(name string) (string, bool) {
if correctName, found := codexToolNameMapping[name]; found {
return correctName, true
}
return name, false
}
// GetToolNameMapping 获取工具名称映射表
func GetToolNameMapping() map[string]string {
// 返回副本以避免外部修改
mapping := make(map[string]string, len(codexToolNameMapping))
for k, v := range codexToolNameMapping {
mapping[k] = v
}
return mapping
}
package service
import (
"encoding/json"
"testing"
)
func TestCorrectToolCallsInSSEData(t *testing.T) {
corrector := NewCodexToolCorrector()
tests := []struct {
name string
input string
expectCorrected bool
checkFunc func(t *testing.T, result string)
}{
{
name: "empty string",
input: "",
expectCorrected: false,
},
{
name: "newline only",
input: "\n",
expectCorrected: false,
},
{
name: "invalid json",
input: "not a json",
expectCorrected: false,
},
{
name: "correct apply_patch in tool_calls",
input: `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
expectCorrected: true,
checkFunc: func(t *testing.T, result string) {
var payload map[string]any
if err := json.Unmarshal([]byte(result), &payload); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
toolCalls, ok := payload["tool_calls"].([]any)
if !ok || len(toolCalls) == 0 {
t.Fatal("No tool_calls found in result")
}
toolCall, ok := toolCalls[0].(map[string]any)
if !ok {
t.Fatal("Invalid tool_call format")
}
functionCall, ok := toolCall["function"].(map[string]any)
if !ok {
t.Fatal("Invalid function format")
}
if functionCall["name"] != "edit" {
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
}
},
},
{
name: "correct update_plan in function_call",
input: `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
expectCorrected: true,
checkFunc: func(t *testing.T, result string) {
var payload map[string]any
if err := json.Unmarshal([]byte(result), &payload); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
functionCall, ok := payload["function_call"].(map[string]any)
if !ok {
t.Fatal("Invalid function_call format")
}
if functionCall["name"] != "todowrite" {
t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
}
},
},
{
name: "correct search_files in delta.tool_calls",
input: `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
expectCorrected: true,
checkFunc: func(t *testing.T, result string) {
var payload map[string]any
if err := json.Unmarshal([]byte(result), &payload); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
delta, ok := payload["delta"].(map[string]any)
if !ok {
t.Fatal("Invalid delta format")
}
toolCalls, ok := delta["tool_calls"].([]any)
if !ok || len(toolCalls) == 0 {
t.Fatal("No tool_calls found in delta")
}
toolCall, ok := toolCalls[0].(map[string]any)
if !ok {
t.Fatal("Invalid tool_call format")
}
functionCall, ok := toolCall["function"].(map[string]any)
if !ok {
t.Fatal("Invalid function format")
}
if functionCall["name"] != "grep" {
t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
}
},
},
{
name: "correct list_files in choices.message.tool_calls",
input: `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
expectCorrected: true,
checkFunc: func(t *testing.T, result string) {
var payload map[string]any
if err := json.Unmarshal([]byte(result), &payload); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
t.Fatal("No choices found in result")
}
choice, ok := choices[0].(map[string]any)
if !ok {
t.Fatal("Invalid choice format")
}
message, ok := choice["message"].(map[string]any)
if !ok {
t.Fatal("Invalid message format")
}
toolCalls, ok := message["tool_calls"].([]any)
if !ok || len(toolCalls) == 0 {
t.Fatal("No tool_calls found in message")
}
toolCall, ok := toolCalls[0].(map[string]any)
if !ok {
t.Fatal("Invalid tool_call format")
}
functionCall, ok := toolCall["function"].(map[string]any)
if !ok {
t.Fatal("Invalid function format")
}
if functionCall["name"] != "glob" {
t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
}
},
},
{
name: "no correction needed",
input: `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
expectCorrected: false,
},
{
name: "correct multiple tool calls",
input: `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
expectCorrected: true,
checkFunc: func(t *testing.T, result string) {
var payload map[string]any
if err := json.Unmarshal([]byte(result), &payload); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
toolCalls, ok := payload["tool_calls"].([]any)
if !ok || len(toolCalls) < 2 {
t.Fatal("Expected at least 2 tool_calls")
}
toolCall1, ok := toolCalls[0].(map[string]any)
if !ok {
t.Fatal("Invalid first tool_call format")
}
func1, ok := toolCall1["function"].(map[string]any)
if !ok {
t.Fatal("Invalid first function format")
}
if func1["name"] != "edit" {
t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
}
toolCall2, ok := toolCalls[1].(map[string]any)
if !ok {
t.Fatal("Invalid second tool_call format")
}
func2, ok := toolCall2["function"].(map[string]any)
if !ok {
t.Fatal("Invalid second function format")
}
if func2["name"] != "read" {
t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
}
},
},
{
name: "camelCase format - applyPatch",
input: `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
expectCorrected: true,
checkFunc: func(t *testing.T, result string) {
var payload map[string]any
if err := json.Unmarshal([]byte(result), &payload); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
toolCalls, ok := payload["tool_calls"].([]any)
if !ok || len(toolCalls) == 0 {
t.Fatal("No tool_calls found in result")
}
toolCall, ok := toolCalls[0].(map[string]any)
if !ok {
t.Fatal("Invalid tool_call format")
}
functionCall, ok := toolCall["function"].(map[string]any)
if !ok {
t.Fatal("Invalid function format")
}
if functionCall["name"] != "edit" {
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
if corrected != tt.expectCorrected {
t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
}
if !corrected && result != tt.input {
t.Errorf("Expected unchanged result when not corrected")
}
if tt.checkFunc != nil {
tt.checkFunc(t, result)
}
})
}
}
func TestCorrectToolName(t *testing.T) {
tests := []struct {
input string
expected string
corrected bool
}{
{"apply_patch", "edit", true},
{"applyPatch", "edit", true},
{"update_plan", "todowrite", true},
{"updatePlan", "todowrite", true},
{"read_plan", "todoread", true},
{"readPlan", "todoread", true},
{"search_files", "grep", true},
{"searchFiles", "grep", true},
{"list_files", "glob", true},
{"listFiles", "glob", true},
{"read_file", "read", true},
{"readFile", "read", true},
{"write_file", "write", true},
{"writeFile", "write", true},
{"execute_bash", "bash", true},
{"executeBash", "bash", true},
{"exec_bash", "bash", true},
{"execBash", "bash", true},
{"unknown_tool", "unknown_tool", false},
{"read", "read", false},
{"edit", "edit", false},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result, corrected := CorrectToolName(tt.input)
if corrected != tt.corrected {
t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
}
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestGetToolNameMapping(t *testing.T) {
mapping := GetToolNameMapping()
expectedMappings := map[string]string{
"apply_patch": "edit",
"update_plan": "todowrite",
"read_plan": "todoread",
"search_files": "grep",
"list_files": "glob",
}
for from, to := range expectedMappings {
if mapping[from] != to {
t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
}
}
mapping["test_tool"] = "test_value"
newMapping := GetToolNameMapping()
if _, exists := newMapping["test_tool"]; exists {
t.Error("Modifications to returned mapping should not affect original")
}
}
func TestCorrectorStats(t *testing.T) {
corrector := NewCodexToolCorrector()
stats := corrector.GetStats()
if stats.TotalCorrected != 0 {
t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
}
if len(stats.CorrectionsByTool) != 0 {
t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
}
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
stats = corrector.GetStats()
if stats.TotalCorrected != 3 {
t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
}
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
}
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
}
corrector.ResetStats()
stats = corrector.GetStats()
if stats.TotalCorrected != 0 {
t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
}
if len(stats.CorrectionsByTool) != 0 {
t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
}
}
func TestComplexSSEData(t *testing.T) {
corrector := NewCodexToolCorrector()
input := `{
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "gpt-5.1-codex",
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": 0,
"function": {
"name": "apply_patch",
"arguments": "{\"file\":\"test.go\"}"
}
}
]
},
"finish_reason": null
}
]
}`
result, corrected := corrector.CorrectToolCallsInSSEData(input)
if !corrected {
t.Error("Expected data to be corrected")
}
var payload map[string]any
if err := json.Unmarshal([]byte(result), &payload); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
t.Fatal("No choices found in result")
}
choice, ok := choices[0].(map[string]any)
if !ok {
t.Fatal("Invalid choice format")
}
delta, ok := choice["delta"].(map[string]any)
if !ok {
t.Fatal("Invalid delta format")
}
toolCalls, ok := delta["tool_calls"].([]any)
if !ok || len(toolCalls) == 0 {
t.Fatal("No tool_calls found in delta")
}
toolCall, ok := toolCalls[0].(map[string]any)
if !ok {
t.Fatal("Invalid tool_call format")
}
function, ok := toolCall["function"].(map[string]any)
if !ok {
t.Fatal("Invalid function format")
}
if function["name"] != "edit" {
t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
}
}
// TestCorrectToolParameters 测试工具参数修正
func TestCorrectToolParameters(t *testing.T) {
corrector := NewCodexToolCorrector()
tests := []struct {
name string
input string
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
}{
{
name: "remove workdir from bash tool",
input: `{
"tool_calls": [{
"function": {
"name": "bash",
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
}
}]
}`,
expected: map[string]bool{
"command": true,
"workdir": false,
},
},
{
name: "rename path to file_path in edit tool",
input: `{
"tool_calls": [{
"function": {
"name": "apply_patch",
"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
}
}]
}`,
expected: map[string]bool{
"file_path": true,
"path": false,
"old_string": true,
"new_string": true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
if !changed {
t.Error("expected data to be corrected")
}
// 解析修正后的数据
var result map[string]any
if err := json.Unmarshal([]byte(corrected), &result); err != nil {
t.Fatalf("failed to parse corrected data: %v", err)
}
// 检查工具调用
toolCalls, ok := result["tool_calls"].([]any)
if !ok || len(toolCalls) == 0 {
t.Fatal("no tool_calls found in corrected data")
}
toolCall, ok := toolCalls[0].(map[string]any)
if !ok {
t.Fatal("invalid tool_call structure")
}
function, ok := toolCall["function"].(map[string]any)
if !ok {
t.Fatal("no function found in tool_call")
}
argumentsStr, ok := function["arguments"].(string)
if !ok {
t.Fatal("arguments is not a string")
}
var args map[string]any
if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
t.Fatalf("failed to parse arguments: %v", err)
}
// 验证期望的参数
for param, shouldExist := range tt.expected {
_, exists := args[param]
if shouldExist && !exists {
t.Errorf("expected parameter %q to exist, but it doesn't", param)
}
if !shouldExist && exists {
t.Errorf("expected parameter %q to not exist, but it does", param)
}
}
})
}
}
......@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
"sync"
......@@ -235,11 +236,13 @@ func (s *OpsAggregationService) aggregateHourly() {
successAt := finishedAt
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer hbCancel()
result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
JobName: opsAggHourlyJobName,
LastRunAt: &runAt,
LastSuccessAt: &successAt,
LastDurationMs: &dur,
LastResult: &result,
})
}
......@@ -331,11 +334,13 @@ func (s *OpsAggregationService) aggregateDaily() {
successAt := finishedAt
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer hbCancel()
result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
JobName: opsAggDailyJobName,
LastRunAt: &runAt,
LastSuccessAt: &successAt,
LastDurationMs: &dur,
LastResult: &result,
})
}
......
......@@ -190,6 +190,13 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
return
}
rulesTotal := len(rules)
rulesEnabled := 0
rulesEvaluated := 0
eventsCreated := 0
eventsResolved := 0
emailsSent := 0
now := time.Now().UTC()
safeEnd := now.Truncate(time.Minute)
if safeEnd.IsZero() {
......@@ -205,8 +212,9 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
if rule == nil || !rule.Enabled || rule.ID <= 0 {
continue
}
rulesEnabled++
scopePlatform, scopeGroupID := parseOpsAlertRuleScope(rule.Filters)
scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
windowMinutes := rule.WindowMinutes
if windowMinutes <= 0 {
......@@ -220,6 +228,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
s.resetRuleState(rule.ID, now)
continue
}
rulesEvaluated++
breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
......@@ -236,6 +245,17 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
// Scoped silencing: if a matching silence exists, skip creating a firing event.
if s.opsService != nil {
platform := strings.TrimSpace(scopePlatform)
region := scopeRegion
if platform != "" {
if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
continue
}
}
}
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
......@@ -267,8 +287,11 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
eventsCreated++
if created != nil && created.ID > 0 {
s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created)
if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
emailsSent++
}
}
continue
}
......@@ -278,11 +301,14 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
resolvedAt := now
if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
} else {
eventsResolved++
}
}
}
s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
}
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
......@@ -359,9 +385,9 @@ func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int
return required
}
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64) {
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
if filters == nil {
return "", nil
return "", nil, nil
}
if v, ok := filters["platform"]; ok {
if s, ok := v.(string); ok {
......@@ -392,7 +418,15 @@ func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *i
}
}
}
return platform, groupID
if v, ok := filters["region"]; ok {
if s, ok := v.(string); ok {
vv := strings.TrimSpace(s)
if vv != "" {
region = &vv
}
}
}
return platform, groupID, region
}
func (s *OpsAlertEvaluatorService) computeRuleMetric(
......@@ -504,16 +538,6 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
return 0, false
}
return overview.UpstreamErrorRate * 100, true
case "p95_latency_ms":
if overview.Duration.P95 == nil {
return 0, false
}
return float64(*overview.Duration.P95), true
case "p99_latency_ms":
if overview.Duration.P99 == nil {
return 0, false
}
return float64(*overview.Duration.P99), true
default:
return 0, false
}
......@@ -576,32 +600,32 @@ func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes i
)
}
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) {
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
return
return false
}
if event.EmailSent {
return
return false
}
if !rule.NotifyEmail {
return
return false
}
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
return
return false
}
if len(emailCfg.Alert.Recipients) == 0 {
return
return false
}
if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
return
return false
}
if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
return
return false
}
}
......@@ -630,6 +654,7 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
if anySent {
_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
}
return anySent
}
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
......@@ -797,7 +822,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
}
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
if s == nil || s.opsRepo == nil {
return
}
......@@ -805,11 +830,17 @@ func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, durat
durMs := duration.Milliseconds()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
msg := strings.TrimSpace(result)
if msg == "" {
msg = "ok"
}
msg = truncateString(msg, 2048)
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
JobName: opsAlertEvaluatorJobName,
LastRunAt: &runAt,
LastSuccessAt: &now,
LastDurationMs: &durMs,
LastResult: &msg,
})
}
......
......@@ -8,8 +8,9 @@ import "time"
// with the existing ops dashboard frontend (backup style).
const (
OpsAlertStatusFiring = "firing"
OpsAlertStatusResolved = "resolved"
OpsAlertStatusFiring = "firing"
OpsAlertStatusResolved = "resolved"
OpsAlertStatusManualResolved = "manual_resolved"
)
type OpsAlertRule struct {
......@@ -58,12 +59,32 @@ type OpsAlertEvent struct {
CreatedAt time.Time `json:"created_at"`
}
type OpsAlertSilence struct {
ID int64 `json:"id"`
RuleID int64 `json:"rule_id"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id,omitempty"`
Region *string `json:"region,omitempty"`
Until time.Time `json:"until"`
Reason string `json:"reason"`
CreatedBy *int64 `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
type OpsAlertEventFilter struct {
Limit int
// Cursor pagination (descending by fired_at, then id).
BeforeFiredAt *time.Time
BeforeID *int64
// Optional filters.
Status string
Severity string
Status string
Severity string
EmailSent *bool
StartTime *time.Time
EndTime *time.Time
......
......@@ -88,6 +88,29 @@ func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventF
return s.opsRepo.ListAlertEvents(ctx, filter)
}
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if eventID <= 0 {
return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
}
ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
}
return nil, err
}
if ev == nil {
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
}
return ev, nil
}
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
......@@ -101,6 +124,49 @@ func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*Op
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
}
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if input == nil {
return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
}
if input.RuleID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
}
if strings.TrimSpace(input.Platform) == "" {
return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
}
if input.Until.IsZero() {
return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
}
created, err := s.opsRepo.CreateAlertSilence(ctx, input)
if err != nil {
return nil, err
}
return created, nil
}
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return false, err
}
if s.opsRepo == nil {
return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if ruleID <= 0 {
return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
}
if strings.TrimSpace(platform) == "" {
return false, nil
}
return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
}
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
......@@ -142,7 +208,11 @@ func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64,
if eventID <= 0 {
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
}
if strings.TrimSpace(status) == "" {
status = strings.TrimSpace(status)
if status == "" {
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
}
if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
}
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
......
......@@ -149,7 +149,7 @@ func (s *OpsCleanupService) runScheduled() {
log.Printf("[OpsCleanup] cleanup failed: %v", err)
return
}
s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts)
log.Printf("[OpsCleanup] cleanup complete: %s", counts)
}
......@@ -330,12 +330,13 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b
return release, true
}
func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, counts opsCleanupDeletedCounts) {
if s == nil || s.opsRepo == nil {
return
}
now := time.Now().UTC()
durMs := duration.Milliseconds()
result := truncateString(counts.String(), 2048)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
......@@ -343,6 +344,7 @@ func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration tim
LastRunAt: &runAt,
LastSuccessAt: &now,
LastDurationMs: &durMs,
LastResult: &result,
})
}
......
......@@ -32,49 +32,38 @@ func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview)
}
// computeBusinessHealth calculates business health score (0-100)
// Components: SLA (50%) + Error Rate (30%) + Latency (20%)
// Components: Error Rate (50%) + TTFT (50%)
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
// SLA score: 99.5% → 100, 95% → 0 (linear)
slaScore := 100.0
slaPct := clampFloat64(overview.SLA*100, 0, 100)
if slaPct < 99.5 {
if slaPct >= 95 {
slaScore = (slaPct - 95) / 4.5 * 100
} else {
slaScore = 0
}
}
// Error rate score: 0.5% → 100, 5% → 0 (linear)
// Error rate score: 1% → 100, 10% → 0 (linear)
// Combines request errors and upstream errors
errorScore := 100.0
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
if combinedErrorPct > 0.5 {
if combinedErrorPct <= 5 {
errorScore = (5 - combinedErrorPct) / 4.5 * 100
if combinedErrorPct > 1.0 {
if combinedErrorPct <= 10.0 {
errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
} else {
errorScore = 0
}
}
// Latency score: 1s → 100, 10s → 0 (linear)
// Uses P99 of duration (TTFT is less critical for overall health)
latencyScore := 100.0
if overview.Duration.P99 != nil {
p99 := float64(*overview.Duration.P99)
// TTFT score: 1s → 100, 3s → 0 (linear)
// Time to first token is critical for user experience
ttftScore := 100.0
if overview.TTFT.P99 != nil {
p99 := float64(*overview.TTFT.P99)
if p99 > 1000 {
if p99 <= 10000 {
latencyScore = (10000 - p99) / 9000 * 100
if p99 <= 3000 {
ttftScore = (3000 - p99) / 2000 * 100
} else {
latencyScore = 0
ttftScore = 0
}
}
}
// Weighted combination
return slaScore*0.5 + errorScore*0.3 + latencyScore*0.2
// Weighted combination: 50% error rate + 50% TTFT
return errorScore*0.5 + ttftScore*0.5
}
// computeInfraHealth calculates infrastructure health score (0-100)
......
......@@ -127,8 +127,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent: float64Ptr(75),
},
},
wantMin: 60,
wantMax: 85,
wantMin: 96,
wantMax: 97,
},
{
name: "DB failure",
......@@ -203,8 +203,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent: float64Ptr(30),
},
},
wantMin: 25,
wantMax: 50,
wantMin: 84,
wantMax: 85,
},
{
name: "combined failures - business healthy + infra degraded",
......@@ -277,30 +277,41 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 50,
wantMax: 60,
wantMin: 100,
wantMax: 100,
},
{
name: "error rate boundary 0.5%",
name: "error rate boundary 1%",
overview: &OpsDashboardOverview{
SLA: 0.995,
ErrorRate: 0.005,
SLA: 0.99,
ErrorRate: 0.01,
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 95,
wantMin: 100,
wantMax: 100,
},
{
name: "latency boundary 1000ms",
name: "error rate 5%",
overview: &OpsDashboardOverview{
SLA: 0.995,
SLA: 0.95,
ErrorRate: 0.05,
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 77,
wantMax: 78,
},
{
name: "TTFT boundary 2s",
overview: &OpsDashboardOverview{
SLA: 0.99,
ErrorRate: 0,
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(1000)},
TTFT: OpsPercentiles{P99: intPtr(2000)},
},
wantMin: 95,
wantMax: 100,
wantMin: 75,
wantMax: 75,
},
{
name: "upstream error dominates",
......@@ -310,7 +321,7 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate: 0.03,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 75,
wantMin: 88,
wantMax: 90,
},
}
......
......@@ -6,24 +6,43 @@ type OpsErrorLog struct {
ID int64 `json:"id"`
CreatedAt time.Time `json:"created_at"`
Phase string `json:"phase"`
Type string `json:"type"`
// Standardized classification
// - phase: request|auth|routing|upstream|network|internal
// - owner: client|provider|platform
// - source: client_request|upstream_http|gateway
Phase string `json:"phase"`
Type string `json:"type"`
Owner string `json:"error_owner"`
Source string `json:"error_source"`
Severity string `json:"severity"`
StatusCode int `json:"status_code"`
Platform string `json:"platform"`
Model string `json:"model"`
LatencyMs *int `json:"latency_ms"`
IsRetryable bool `json:"is_retryable"`
RetryCount int `json:"retry_count"`
Resolved bool `json:"resolved"`
ResolvedAt *time.Time `json:"resolved_at"`
ResolvedByUserID *int64 `json:"resolved_by_user_id"`
ResolvedByUserName string `json:"resolved_by_user_name"`
ResolvedRetryID *int64 `json:"resolved_retry_id"`
ResolvedStatusRaw string `json:"-"`
ClientRequestID string `json:"client_request_id"`
RequestID string `json:"request_id"`
Message string `json:"message"`
UserID *int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
UserID *int64 `json:"user_id"`
UserEmail string `json:"user_email"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
AccountName string `json:"account_name"`
GroupID *int64 `json:"group_id"`
GroupName string `json:"group_name"`
ClientIP *string `json:"client_ip"`
RequestPath string `json:"request_path"`
......@@ -67,9 +86,24 @@ type OpsErrorLogFilter struct {
GroupID *int64
AccountID *int64
StatusCodes []int
Phase string
Query string
StatusCodes []int
StatusCodesOther bool
Phase string
Owner string
Source string
Resolved *bool
Query string
UserQuery string // Search by user email
// Optional correlation keys for exact matching.
RequestID string
ClientRequestID string
// View controls error categorization for list endpoints.
// - errors: show actionable errors (exclude business-limited / 429 / 529)
// - excluded: only show excluded errors
// - all: show everything
View string
Page int
PageSize int
......@@ -90,12 +124,23 @@ type OpsRetryAttempt struct {
SourceErrorID int64 `json:"source_error_id"`
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
PinnedAccountName string `json:"pinned_account_name"`
Status string `json:"status"`
StartedAt *time.Time `json:"started_at"`
FinishedAt *time.Time `json:"finished_at"`
DurationMs *int64 `json:"duration_ms"`
// Persisted execution results (best-effort)
Success *bool `json:"success"`
HTTPStatusCode *int `json:"http_status_code"`
UpstreamRequestID *string `json:"upstream_request_id"`
UsedAccountID *int64 `json:"used_account_id"`
UsedAccountName string `json:"used_account_name"`
ResponsePreview *string `json:"response_preview"`
ResponseTruncated *bool `json:"response_truncated"`
// Optional correlation
ResultRequestID *string `json:"result_request_id"`
ResultErrorID *int64 `json:"result_error_id"`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment