Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6901b64f
Commit
6901b64f
authored
Jan 17, 2026
by
cyhhao
Browse files
merge: sync upstream changes
parents
32c47b15
dae0d532
Changes
189
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/gemini_messages_compat_service.go
View file @
6901b64f
...
...
@@ -545,12 +545,19 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader
=
idHeader
// Capture upstream request body for ops retry of this attempt.
if
c
!=
nil
{
// In this code path `body` is already the JSON sent to upstream.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
...
...
@@ -588,6 +595,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"signature_error"
,
...
...
@@ -662,6 +670,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"retry"
,
...
...
@@ -711,6 +720,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"failover"
,
...
...
@@ -737,6 +747,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"failover"
,
...
...
@@ -972,12 +983,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader
=
idHeader
// Capture upstream request body for ops retry of this attempt.
if
c
!=
nil
{
// In this code path `body` is already the JSON sent to upstream.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
...
...
@@ -1036,6 +1054,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"retry"
,
...
...
@@ -1120,6 +1139,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"failover"
,
...
...
@@ -1143,6 +1163,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"failover"
,
...
...
@@ -1168,6 +1189,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"http_error"
,
...
...
@@ -1300,6 +1322,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
upstreamStatus
,
UpstreamRequestID
:
upstreamRequestID
,
Kind
:
"http_error"
,
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
6901b64f
...
...
@@ -125,6 +125,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
// SetAntigravityQuotaScopeLimit is a no-op stub satisfying the account-repository interface.
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
	return nil
}

// SetModelRateLimit is a no-op stub satisfying the account-repository interface.
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
	return nil
}

// SetOverloaded is a no-op stub satisfying the account-repository interface.
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
	return nil
}
...
...
@@ -138,6 +141,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
// ClearAntigravityQuotaScopes is a no-op stub satisfying the account-repository interface.
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
	return nil
}

// ClearModelRateLimits is a no-op stub satisfying the account-repository interface.
func (m *mockAccountRepoForGemini) ClearModelRateLimits(ctx context.Context, id int64) error {
	return nil
}

// UpdateSessionWindow is a no-op stub satisfying the account-repository interface.
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
	return nil
}
...
...
backend/internal/service/gemini_token_cache.go
View file @
6901b64f
...
...
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
...
...
backend/internal/service/gemini_token_provider.go
View file @
6901b64f
...
...
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"not a gemini oauth account"
)
}
cacheKey
:=
g
eminiTokenCacheKey
(
account
)
cacheKey
:=
G
eminiTokenCacheKey
(
account
)
// 1) Try cache first.
if
p
.
tokenCache
!=
nil
{
...
...
@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
accessToken
,
nil
}
func
g
eminiTokenCacheKey
(
account
*
Account
)
string
{
func
G
eminiTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
return
projectID
return
"gemini:"
+
projectID
}
return
"account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
return
"
gemini:
account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
backend/internal/service/group.go
View file @
6901b64f
package
service
import
"time"
import
(
"strings"
"time"
)
type
Group
struct
{
ID
int64
...
...
@@ -27,6 +30,12 @@ type Group struct {
ClaudeCodeOnly
bool
FallbackGroupID
*
int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
// value: 优先账号 ID 列表
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
@@ -90,3 +99,41 @@ func IsGroupContextValid(group *Group) bool {
}
return
true
}
// GetRoutingAccountIDs 根据请求模型获取路由账号 ID 列表
// 返回匹配的优先账号 ID 列表,如果没有匹配规则则返回 nil
func
(
g
*
Group
)
GetRoutingAccountIDs
(
requestedModel
string
)
[]
int64
{
if
!
g
.
ModelRoutingEnabled
||
len
(
g
.
ModelRouting
)
==
0
||
requestedModel
==
""
{
return
nil
}
// 1. 精确匹配优先
if
accountIDs
,
ok
:=
g
.
ModelRouting
[
requestedModel
];
ok
&&
len
(
accountIDs
)
>
0
{
return
accountIDs
}
// 2. 通配符匹配(前缀匹配)
for
pattern
,
accountIDs
:=
range
g
.
ModelRouting
{
if
matchModelPattern
(
pattern
,
requestedModel
)
&&
len
(
accountIDs
)
>
0
{
return
accountIDs
}
}
return
nil
}
// matchModelPattern reports whether model matches pattern.
// A pattern is either an exact model name or a prefix followed by a single
// trailing "*" wildcard, e.g. "claude-opus-*" matches "claude-opus-4-20250514".
func matchModelPattern(pattern, model string) bool {
	if pattern == model {
		return true
	}
	// Only a trailing "*" wildcard is supported.
	if !strings.HasSuffix(pattern, "*") {
		return false
	}
	return strings.HasPrefix(model, strings.TrimSuffix(pattern, "*"))
}
backend/internal/service/model_rate_limit.go
0 → 100644
View file @
6901b64f
package
service
import
(
"strings"
"time"
)
// modelRateLimitsKey is the Account.Extra key holding per-scope rate limits.
const modelRateLimitsKey = "model_rate_limits"

// modelRateLimitScopeClaudeSonnet groups all Claude Sonnet variants into one scope.
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"

// resolveModelRateLimitScope maps a requested model name to its rate-limit
// scope. The second return value is false when the model does not belong to
// any known scope (currently only Sonnet-family models are scoped).
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(requestedModel))
	if normalized == "" {
		return "", false
	}
	// Accept both "models/<name>" and bare "<name>" forms.
	normalized = strings.TrimPrefix(normalized, "models/")
	if !strings.Contains(normalized, "sonnet") {
		return "", false
	}
	return modelRateLimitScopeClaudeSonnet, true
}
func
(
a
*
Account
)
isModelRateLimited
(
requestedModel
string
)
bool
{
scope
,
ok
:=
resolveModelRateLimitScope
(
requestedModel
)
if
!
ok
{
return
false
}
resetAt
:=
a
.
modelRateLimitResetAt
(
scope
)
if
resetAt
==
nil
{
return
false
}
return
time
.
Now
()
.
Before
(
*
resetAt
)
}
func
(
a
*
Account
)
modelRateLimitResetAt
(
scope
string
)
*
time
.
Time
{
if
a
==
nil
||
a
.
Extra
==
nil
||
scope
==
""
{
return
nil
}
rawLimits
,
ok
:=
a
.
Extra
[
modelRateLimitsKey
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
rawLimit
,
ok
:=
rawLimits
[
scope
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
resetAtRaw
,
ok
:=
rawLimit
[
"rate_limit_reset_at"
]
.
(
string
)
if
!
ok
||
strings
.
TrimSpace
(
resetAtRaw
)
==
""
{
return
nil
}
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtRaw
)
if
err
!=
nil
{
return
nil
}
return
&
resetAt
}
backend/internal/service/openai_gateway_service.go
View file @
6901b64f
...
...
@@ -93,6 +93,8 @@ type OpenAIGatewayService struct {
billingCacheService
*
BillingCacheService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
openAITokenProvider
*
OpenAITokenProvider
toolCorrector
*
CodexToolCorrector
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
...
...
@@ -110,6 +112,7 @@ func NewOpenAIGatewayService(
billingCacheService
*
BillingCacheService
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
openAITokenProvider
*
OpenAITokenProvider
,
)
*
OpenAIGatewayService
{
return
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
...
...
@@ -125,6 +128,8 @@ func NewOpenAIGatewayService(
billingCacheService
:
billingCacheService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
openAITokenProvider
:
openAITokenProvider
,
toolCorrector
:
NewCodexToolCorrector
(),
}
}
...
...
@@ -503,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func
(
s
*
OpenAIGatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
case
AccountTypeOAuth
:
// 使用 TokenProvider 获取缓存的 token
if
s
.
openAITokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
openAITokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 降级:TokenProvider 未配置时直接从账号读取
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
@@ -664,6 +678,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL
=
account
.
Proxy
.
URL
()
}
// Capture upstream request body for ops retry of this attempt.
if
c
!=
nil
{
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
}
// Send request
resp
,
err
:=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
...
...
@@ -673,6 +692,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
...
...
@@ -707,6 +727,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"failover"
,
...
...
@@ -864,6 +885,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"http_error"
,
...
...
@@ -894,6 +916,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
kind
,
...
...
@@ -1097,6 +1120,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
data
=
correctedData
line
=
"data: "
+
correctedData
}
// 写入客户端(客户端断开后继续 drain 上游)
if
!
clientDisconnected
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
...
...
@@ -1199,6 +1228,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return
line
}
// correctToolCallsInResponseBody 修正响应体中的工具调用
func
(
s
*
OpenAIGatewayService
)
correctToolCallsInResponseBody
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
bodyStr
:=
string
(
body
)
corrected
,
changed
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
bodyStr
)
if
changed
{
return
[]
byte
(
corrected
)
}
return
body
}
func
(
s
*
OpenAIGatewayService
)
parseSSEUsage
(
data
string
,
usage
*
OpenAIUsage
)
{
// Parse response.completed event for usage (OpenAI Responses format)
var
event
struct
{
...
...
@@ -1302,6 +1345,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
if
originalModel
!=
mappedModel
{
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
// Correct tool calls in final response
body
=
s
.
correctToolCallsInResponseBody
(
body
)
}
else
{
usage
=
s
.
parseSSEUsageFromBody
(
bodyText
)
if
originalModel
!=
mappedModel
{
...
...
@@ -1470,28 +1515,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
actualInputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
CreatedAt
:
time
.
Now
(),
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
actualInputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
CreatedAt
:
time
.
Now
(),
}
// 添加 UserAgent
...
...
backend/internal/service/openai_gateway_service_tool_correction_test.go
0 → 100644
View file @
6901b64f
package
service
import
(
"strings"
"testing"
)
// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
// TestOpenAIGatewayService_ToolCorrection verifies that tool-call correction
// is wired into OpenAIGatewayService.correctToolCallsInResponseBody via a
// table of response bodies.
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
	// Minimal service instance; only toolCorrector is needed for this test.
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}
	tests := []struct {
		name     string
		input    []byte
		expected string // tool name the corrected body must contain
		changed  bool   // whether correction is expected to alter the body
	}{
		{
			name: "correct apply_patch in response body",
			input: []byte(`{
"choices": [{
"message": {
"tool_calls": [{
"function": {"name": "apply_patch"}
}]
}
}]
}`),
			expected: "edit",
			changed:  true,
		},
		{
			name: "correct update_plan in response body",
			input: []byte(`{
"tool_calls": [{
"function": {"name": "update_plan"}
}]
}`),
			expected: "todowrite",
			changed:  true,
		},
		{
			name: "no change for correct tool name",
			input: []byte(`{
"tool_calls": [{
"function": {"name": "edit"}
}]
}`),
			expected: "edit",
			changed:  false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := service.correctToolCallsInResponseBody(tt.input)
			resultStr := string(result)
			// The corrected body must mention the expected tool name.
			if !strings.Contains(resultStr, tt.expected) {
				t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
			}
			// When a change is expected, the output must differ from the input.
			if tt.changed && string(result) == string(tt.input) {
				t.Error("expected result to be different from input, but they are the same")
			}
			// When no change is expected, the output must equal the input.
			if !tt.changed && string(result) != string(tt.input) {
				t.Error("expected result to be same as input, but they are different")
			}
		})
	}
}
// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
func
TestOpenAIGatewayService_ToolCorrectorInitialization
(
t
*
testing
.
T
)
{
service
:=
&
OpenAIGatewayService
{
toolCorrector
:
NewCodexToolCorrector
(),
}
if
service
.
toolCorrector
==
nil
{
t
.
Fatal
(
"toolCorrector should not be nil"
)
}
// 测试修正器可以正常工作
data
:=
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
corrected
,
changed
:=
service
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
)
if
!
changed
{
t
.
Error
(
"expected tool call to be corrected"
)
}
if
!
strings
.
Contains
(
corrected
,
"edit"
)
{
t
.
Errorf
(
"expected corrected data to contain 'edit', got %q"
,
corrected
)
}
}
// TestToolCorrectionStats 测试工具修正统计功能
func
TestToolCorrectionStats
(
t
*
testing
.
T
)
{
service
:=
&
OpenAIGatewayService
{
toolCorrector
:
NewCodexToolCorrector
(),
}
// 执行几次修正
testData
:=
[]
string
{
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
,
`{"tool_calls":[{"function":{"name":"update_plan"}}]}`
,
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
,
}
for
_
,
data
:=
range
testData
{
service
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
)
}
stats
:=
service
.
toolCorrector
.
GetStats
()
if
stats
.
TotalCorrected
!=
3
{
t
.
Errorf
(
"expected 3 corrections, got %d"
,
stats
.
TotalCorrected
)
}
if
stats
.
CorrectionsByTool
[
"apply_patch->edit"
]
!=
2
{
t
.
Errorf
(
"expected 2 apply_patch->edit corrections, got %d"
,
stats
.
CorrectionsByTool
[
"apply_patch->edit"
])
}
if
stats
.
CorrectionsByTool
[
"update_plan->todowrite"
]
!=
1
{
t
.
Errorf
(
"expected 1 update_plan->todowrite correction, got %d"
,
stats
.
CorrectionsByTool
[
"update_plan->todowrite"
])
}
}
backend/internal/service/openai_token_provider.go
0 → 100644
View file @
6901b64f
package
service
import
(
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
	// openAITokenRefreshSkew: refresh when the token expires within this window.
	openAITokenRefreshSkew = 3 * time.Minute
	// openAITokenCacheSkew: safety margin subtracted from expires_at for cache TTL.
	openAITokenCacheSkew = 5 * time.Minute
	// openAILockWaitTime: wait before re-reading the cache when another worker holds the refresh lock.
	openAILockWaitTime = 200 * time.Millisecond
)

// OpenAITokenCache is the token cache interface (reuses the GeminiTokenCache definition).
type OpenAITokenCache = GeminiTokenCache

// OpenAITokenProvider manages access tokens for OpenAI OAuth accounts:
// cache lookup, distributed-lock refresh, and write-back.
type OpenAITokenProvider struct {
	accountRepo        AccountRepository
	tokenCache         OpenAITokenCache
	openAIOAuthService *OpenAIOAuthService
}

// NewOpenAITokenProvider constructs an OpenAITokenProvider. Any dependency may
// be nil; GetAccessToken degrades gracefully when cache or OAuth service is absent.
func NewOpenAITokenProvider(
	accountRepo AccountRepository,
	tokenCache OpenAITokenCache,
	openAIOAuthService *OpenAIOAuthService,
) *OpenAITokenProvider {
	return &OpenAITokenProvider{
		accountRepo:        accountRepo,
		tokenCache:         tokenCache,
		openAIOAuthService: openAIOAuthService,
	}
}
// GetAccessToken returns a valid access_token for an OpenAI OAuth account.
//
// Flow: (1) try the token cache; (2) if the stored token is near expiry,
// refresh it under a distributed lock (with a lock-failure degraded path and
// a lock-contention wait-and-retry path); (3) fall back to the credential
// stored on the account and write it back to the cache with a TTL derived
// from expires_at (short TTL after a failed refresh to avoid caching a stale
// token for long).
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}
	cacheKey := OpenAITokenCacheKey(account)
	// 1. Try the cache first.
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			slog.Debug("openai_token_cache_hit", "account_id", account.ID)
			return token, nil
		} else if err != nil {
			slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
		}
	}
	slog.Debug("openai_token_cache_miss", "account_id", account.ID)
	// 2. Refresh when the token is close to expiry.
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if lockErr == nil && locked {
			defer func() {
				_ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey)
			}()
			// Re-check the cache after taking the lock (another worker may have refreshed).
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				return token, nil
			}
			// Reload the freshest account state from the database.
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.openAIOAuthService == nil {
					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
					refreshFailed = true // cannot refresh; mark as failed
				} else {
					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
					if err != nil {
						// Log a warning but do not fail immediately; fall back to the existing token.
						slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
						refreshFailed = true // refresh failed; use a short cache TTL below
					} else {
						// Merge: keep credential keys the refresh did not overwrite.
						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
						}
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if lockErr != nil {
			// Redis error prevented taking the lock: degrade to a lock-free
			// refresh, but only when the token is actually close to expiry.
			slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
			// Bail out early if the context has been cancelled.
			if ctx.Err() != nil {
				return "", ctx.Err()
			}
			// Reload the freshest account state from the database.
			if p.accountRepo != nil {
				fresh, err := p.accountRepo.GetByID(ctx, account.ID)
				if err == nil && fresh != nil {
					account = fresh
				}
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			// Perform the lock-free refresh only when expires_at is missing or near expiry.
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.openAIOAuthService == nil {
					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
					refreshFailed = true
				} else {
					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
					if err != nil {
						slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
						refreshFailed = true
					} else {
						// Merge: keep credential keys the refresh did not overwrite.
						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
						}
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else {
			// Lock held by another worker: wait 200ms, then retry the cache.
			time.Sleep(openAILockWaitTime)
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
				return token, nil
			}
		}
	}
	accessToken := account.GetOpenAIAccessToken()
	if strings.TrimSpace(accessToken) == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store the token back into the cache.
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			// After a failed refresh, use a short TTL so a possibly-stale token
			// is not cached long enough to cause sustained 401 churn.
			ttl = time.Minute
			slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			switch {
			case until > openAITokenCacheSkew:
				ttl = until - openAITokenCacheSkew
			case until > 0:
				ttl = until
			default:
				ttl = time.Minute
			}
		}
		if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
			slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
		}
	}
	return accessToken, nil
}
backend/internal/service/openai_token_provider_test.go
0 → 100644
View file @
6901b64f
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// openAITokenCacheStub implements OpenAITokenCache for testing
// openAITokenCacheStub is an in-memory OpenAITokenCache test double with
// injectable errors, lock behavior, and atomic call counters.
type openAITokenCacheStub struct {
	mu     sync.Mutex        // guards tokens
	tokens map[string]string // cacheKey -> token

	// Injectable errors for each operation.
	getErr         error
	setErr         error
	deleteErr      error
	lockAcquired   bool
	lockErr        error
	releaseLockErr error

	// Call counters, incremented atomically.
	getCalled    int32
	setCalled    int32
	lockCalled   int32
	unlockCalled int32

	// simulateLockRace makes AcquireRefreshLock report "held by another worker".
	simulateLockRace bool
}

// newOpenAITokenCacheStub returns a stub whose lock acquisition succeeds by default.
func newOpenAITokenCacheStub() *openAITokenCacheStub {
	return &openAITokenCacheStub{
		tokens:       make(map[string]string),
		lockAcquired: true,
	}
}

// GetAccessToken returns the stored token (zero value when absent) or getErr.
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	atomic.AddInt32(&s.getCalled, 1)
	if s.getErr != nil {
		return "", s.getErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[cacheKey], nil
}

// SetAccessToken stores the token (ttl ignored) or returns setErr.
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	atomic.AddInt32(&s.setCalled, 1)
	if s.setErr != nil {
		return s.setErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[cacheKey] = token
	return nil
}

// DeleteAccessToken removes the token or returns deleteErr.
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	if s.deleteErr != nil {
		return s.deleteErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens, cacheKey)
	return nil
}

// AcquireRefreshLock honors lockErr, then simulateLockRace, then lockAcquired.
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	atomic.AddInt32(&s.lockCalled, 1)
	if s.lockErr != nil {
		return false, s.lockErr
	}
	if s.simulateLockRace {
		return false, nil
	}
	return s.lockAcquired, nil
}

// ReleaseRefreshLock counts the call and returns the configured error (nil by default).
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	atomic.AddInt32(&s.unlockCalled, 1)
	return s.releaseLockErr
}
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider.
type openAIAccountRepoStub struct {
	account      *Account // value returned by GetByID / replaced by Update
	getErr       error
	updateErr    error
	getCalled    int32
	updateCalled int32
}

// GetByID returns the configured account or getErr; counts the call.
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
	atomic.AddInt32(&r.getCalled, 1)
	if r.getErr != nil {
		return nil, r.getErr
	}
	return r.account, nil
}

// Update stores the given account or returns updateErr; counts the call.
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
	atomic.AddInt32(&r.updateCalled, 1)
	if r.updateErr != nil {
		return r.updateErr
	}
	r.account = account
	return nil
}
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing.
type openAIOAuthServiceStub struct {
	tokenInfo     *OpenAITokenInfo // returned by RefreshAccountToken on success
	refreshErr    error
	refreshCalled int32
}

// RefreshAccountToken returns the configured tokenInfo or refreshErr; counts the call.
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
	atomic.AddInt32(&s.refreshCalled, 1)
	if s.refreshErr != nil {
		return nil, s.refreshErr
	}
	return s.tokenInfo, nil
}

// BuildAccountCredentials mirrors the real service: token fields plus an
// RFC3339 expires_at computed from ExpiresIn seconds from now.
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
	now := time.Now()
	return map[string]any{
		"access_token":  info.AccessToken,
		"refresh_token": info.RefreshToken,
		"expires_at":    now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
}
// TestOpenAITokenProvider_CacheHit: a cached token is returned without
// touching credentials or writing back to the cache.
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	account := &Account{
		ID:       100,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	cache.tokens[cacheKey] = "cached-token"
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "cached-token", token)
	// Exactly one cache read, no cache write.
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}
// TestOpenAITokenProvider_CacheMiss_FromCredentials: on a cache miss with a
// far-future expiry, the credential token is returned and written to the cache.
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	// Token expires in far future, no refresh needed
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       101,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "credential-token", token)
	// Should have stored in cache
	cacheKey := OpenAITokenCacheKey(account)
	require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
// TestOpenAITokenProvider_TokenRefresh verifies that a token expiring within
// the refresh skew triggers exactly one OAuth refresh and that the refreshed
// token is returned.
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	// We need to directly test with the stub - create a custom provider
	customProvider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	token, err := customProvider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
	// The stub OAuth service must have been asked to refresh exactly once.
	require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
// testOpenAITokenProvider is a test version of the token provider that uses
// the stub repo/cache/OAuth-service types so refresh behavior can be driven
// deterministically from tests.
type testOpenAITokenProvider struct {
	accountRepo  *openAIAccountRepoStub  // source of the "fresh" account after lock acquisition
	tokenCache   *openAITokenCacheStub   // access-token cache plus refresh lock
	oauthService *openAIOAuthServiceStub // performs the actual token refresh; may be nil
}
// GetAccessToken mirrors the production provider's flow against the stubs:
// cache lookup, single-flight refresh behind a lock, then fallback to the
// credential token with a TTL-bounded cache write. Returns an error only for
// a nil/mismatched account or a missing access token.
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}
	cacheKey := OpenAITokenCacheKey(account)
	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}
	// 2. Check if refresh needed
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() {
				_ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey)
			}()
			// Check cache again after acquiring lock
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh; mark as failed
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed; mark so a short TTL is used
					} else {
						newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
						// Preserve any credential keys the refresh did not rewrite.
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account)
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Wait and retry cache
			time.Sleep(10 * time.Millisecond) // Short wait for test
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}
	accessToken := account.GetOpenAIAccessToken()
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // use a short TTL when the refresh failed
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > openAITokenCacheSkew {
				ttl = until - openAITokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}
	return accessToken, nil
}
// TestOpenAITokenProvider_LockRaceCondition verifies that when the refresh
// lock is lost to another worker, the provider waits, retries the cache, and
// still produces a non-empty token.
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &openAIAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	// Simulate another worker already refreshed and cached
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := &testOpenAITokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get the token set by the "winner" or the original
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_NilAccount verifies that a nil account is rejected
// with a descriptive error and no token.
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_WrongPlatform verifies that a non-OpenAI account is
// rejected even when its type is OAuth.
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       104,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_WrongAccountType verifies that an OpenAI account of
// a non-OAuth type (API key) is rejected.
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       105,
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_NilCache verifies that the provider still returns
// the credential token when no cache is configured.
func TestOpenAITokenProvider_NilCache(t *testing.T) {
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       106,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", token)
}
// TestOpenAITokenProvider_CacheGetError verifies graceful degradation: a
// failing cache read falls back to the credential token instead of erroring.
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.getErr = errors.New("redis connection failed")
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       107,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	// Should gracefully degrade and return from credentials
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", token)
}
// TestOpenAITokenProvider_CacheSetError verifies that a failing cache write
// does not prevent the credential token from being returned.
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.setErr = errors.New("redis write failed")
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       108,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	// Should still work even if cache set fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", token)
}
// TestOpenAITokenProvider_MissingAccessToken verifies that an account whose
// credentials lack an access_token produces a clear error.
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       109,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// missing access_token
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_RefreshError verifies the fallback path: when the
// OAuth refresh fails, the provider still returns the existing (old) token
// instead of surfacing the refresh error.
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		refreshErr: errors.New("oauth refresh failed"),
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       110,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Now with fallback behavior, should return existing token even if refresh fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}
// TestOpenAITokenProvider_OAuthServiceNotConfigured verifies that a missing
// OAuth service degrades to returning the existing credential token rather
// than failing.
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       111,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: nil, // not configured
	}
	// Now with fallback behavior, should return existing token even if oauth service not configured
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}
// TestOpenAITokenProvider_TTLCalculation runs a table of expiry horizons and
// verifies that in each case the token ends up stored in the cache.
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
	tests := []struct {
		name      string
		expiresIn time.Duration
	}{
		{
			name:      "far_future_expiry",
			expiresIn: 1 * time.Hour,
		},
		{
			name:      "medium_expiry",
			expiresIn: 10 * time.Minute,
		},
		{
			name:      "near_expiry",
			expiresIn: 6 * time.Minute,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache := newOpenAITokenCacheStub()
			expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
			account := &Account{
				ID:       200,
				Platform: PlatformOpenAI,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   expiresAt,
				},
			}
			provider := NewOpenAITokenProvider(nil, cache, nil)
			_, err := provider.GetAccessToken(context.Background(), account)
			require.NoError(t, err)
			// Verify token was cached
			cacheKey := OpenAITokenCacheKey(account)
			require.Equal(t, "test-token", cache.tokens[cacheKey])
		})
	}
}
func
TestOpenAITokenProvider_DoubleCheckAfterLock
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
accountRepo
:=
&
openAIAccountRepoStub
{}
oauthService
:=
&
openAIOAuthServiceStub
{
tokenInfo
:
&
OpenAITokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
112
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
originalGet
:=
int32
(
0
)
cache
.
tokens
[
cacheKey
]
=
""
// Empty initially
provider
:=
&
testOpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// In a goroutine, set the cached token after a small delay (simulating race)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"cached-by-other"
cache
.
mu
.
Unlock
()
}()
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Should get either the refreshed token or the cached one
require
.
NotEmpty
(
t
,
token
)
_
=
originalGet
// Suppress unused warning
}
// Tests for real provider - to increase coverage
func
TestOpenAITokenProvider_Real_LockFailedWait
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
lockAcquired
=
false
// Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
200
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
,
"expires_at"
:
expiresAt
,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey
:=
OpenAITokenCacheKey
(
account
)
go
func
()
{
time
.
Sleep
(
100
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"refreshed-by-other"
cache
.
mu
.
Unlock
()
}()
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Should get either the fallback token or the refreshed one
require
.
NotEmpty
(
t
,
token
)
}
// TestOpenAITokenProvider_Real_CacheHitAfterWait exercises the real provider
// when the lock is unavailable and another worker caches a token during the
// wait window; the call must still succeed with a non-empty token.
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       201,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken verifies that an
// account with no expires_at credential (and no way to refresh) still yields
// the credential token.
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Prevent entering refresh logic
	// Token with nil expires_at (no expiry set) - should use credentials
	account := &Account{
		ID:       202,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	// Without OAuth service, refresh will fail but token should be returned from credentials
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", token)
}
func
TestOpenAITokenProvider_Real_WhitespaceToken
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cacheKey
:=
"openai:account:203"
cache
.
tokens
[
cacheKey
]
=
" "
// Whitespace only - should be treated as empty
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
203
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"real-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"real-token"
,
token
)
// Should fall back to credentials
}
// TestOpenAITokenProvider_Real_LockError verifies that an error acquiring the
// refresh lock degrades to returning the credential token.
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockErr = errors.New("redis lock failed")
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       204,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", token)
}
// TestOpenAITokenProvider_Real_WhitespaceCredentialToken verifies that a
// whitespace-only credential access_token is rejected as missing.
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       205,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "   ", // Whitespace only
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_Real_NilCredentials verifies that credentials with
// no access_token key at all produce the "not found" error.
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       206,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// No access_token
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
backend/internal/service/openai_tool_corrector.go
0 → 100644
View file @
6901b64f
package
service
import
(
"encoding/json"
"fmt"
"log"
"sync"
)
// codexToolNameMapping maps Codex-native tool names (both snake_case and
// camelCase variants) to their OpenCode tool-name equivalents.
var codexToolNameMapping = map[string]string{
	"apply_patch":  "edit",
	"applyPatch":   "edit",
	"update_plan":  "todowrite",
	"updatePlan":   "todowrite",
	"read_plan":    "todoread",
	"readPlan":     "todoread",
	"search_files": "grep",
	"searchFiles":  "grep",
	"list_files":   "glob",
	"listFiles":    "glob",
	"read_file":    "read",
	"readFile":     "read",
	"write_file":   "write",
	"writeFile":    "write",
	"execute_bash": "bash",
	"executeBash":  "bash",
	"exec_bash":    "bash",
	"execBash":     "bash",
}
// ToolCorrectionStats records statistics about tool-call corrections
// (exported for JSON serialization).
type ToolCorrectionStats struct {
	// TotalCorrected is the total number of tool-name corrections performed.
	TotalCorrected int `json:"total_corrected"`
	// CorrectionsByTool counts corrections keyed by "from->to" mapping.
	CorrectionsByTool map[string]int `json:"corrections_by_tool"`
}
// CodexToolCorrector performs automatic correction of Codex tool calls,
// rewriting tool names/arguments to the OpenCode conventions. Safe for
// concurrent use: stats access is guarded by mu.
type CodexToolCorrector struct {
	stats ToolCorrectionStats
	mu    sync.RWMutex // guards stats
}
// NewCodexToolCorrector 创建新的工具修正器
func
NewCodexToolCorrector
()
*
CodexToolCorrector
{
return
&
CodexToolCorrector
{
stats
:
ToolCorrectionStats
{
CorrectionsByTool
:
make
(
map
[
string
]
int
),
},
}
}
// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用
// 返回修正后的数据和是否进行了修正
func
(
c
*
CodexToolCorrector
)
CorrectToolCallsInSSEData
(
data
string
)
(
string
,
bool
)
{
if
data
==
""
||
data
==
"
\n
"
{
return
data
,
false
}
// 尝试解析 JSON
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
payload
);
err
!=
nil
{
// 不是有效的 JSON,直接返回原数据
return
data
,
false
}
corrected
:=
false
// 处理 tool_calls 数组
if
toolCalls
,
ok
:=
payload
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
// 处理 function_call 对象
if
functionCall
,
ok
:=
payload
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
// 处理 delta.tool_calls
if
delta
,
ok
:=
payload
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
toolCalls
,
ok
:=
delta
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
if
functionCall
,
ok
:=
delta
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
}
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
if
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
);
ok
{
for
_
,
choice
:=
range
choices
{
if
choiceMap
,
ok
:=
choice
.
(
map
[
string
]
any
);
ok
{
// 处理 message 中的工具调用
if
message
,
ok
:=
choiceMap
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
toolCalls
,
ok
:=
message
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
if
functionCall
,
ok
:=
message
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
}
// 处理 delta 中的工具调用
if
delta
,
ok
:=
choiceMap
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
toolCalls
,
ok
:=
delta
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
if
functionCall
,
ok
:=
delta
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
}
}
}
}
if
!
corrected
{
return
data
,
false
}
// 序列化回 JSON
correctedBytes
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
log
.
Printf
(
"[CodexToolCorrector] Failed to marshal corrected data: %v"
,
err
)
return
data
,
false
}
return
string
(
correctedBytes
),
true
}
// correctToolCallsArray 修正工具调用数组中的工具名称
func
(
c
*
CodexToolCorrector
)
correctToolCallsArray
(
toolCalls
[]
any
)
bool
{
corrected
:=
false
for
_
,
toolCall
:=
range
toolCalls
{
if
toolCallMap
,
ok
:=
toolCall
.
(
map
[
string
]
any
);
ok
{
if
function
,
ok
:=
toolCallMap
[
"function"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
function
)
{
corrected
=
true
}
}
}
}
return
corrected
}
// correctFunctionCall 修正单个函数调用的工具名称和参数
func
(
c
*
CodexToolCorrector
)
correctFunctionCall
(
functionCall
map
[
string
]
any
)
bool
{
name
,
ok
:=
functionCall
[
"name"
]
.
(
string
)
if
!
ok
||
name
==
""
{
return
false
}
corrected
:=
false
// 查找并修正工具名称
if
correctName
,
found
:=
codexToolNameMapping
[
name
];
found
{
functionCall
[
"name"
]
=
correctName
c
.
recordCorrection
(
name
,
correctName
)
corrected
=
true
name
=
correctName
// 使用修正后的名称进行参数修正
}
// 修正工具参数(基于工具名称)
if
c
.
correctToolParameters
(
name
,
functionCall
)
{
corrected
=
true
}
return
corrected
}
// correctToolParameters corrects tool arguments to conform to the OpenCode
// conventions. Arguments may arrive either as a JSON string or an
// already-decoded map; whichever form came in is preserved on the way out.
// Returns true if any parameter was changed.
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
	arguments, ok := functionCall["arguments"]
	if !ok {
		return false
	}
	// arguments may be a string (JSON) or an already-parsed map
	var argsMap map[string]any
	switch v := arguments.(type) {
	case string:
		// Parse the JSON string form.
		if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
			return false
		}
	case map[string]any:
		argsMap = v
	default:
		return false
	}
	corrected := false
	// Apply tool-specific parameter correction rules by tool name.
	switch toolName {
	case "bash":
		// Remove the workdir parameter (not supported by OpenCode).
		if _, exists := argsMap["workdir"]; exists {
			delete(argsMap, "workdir")
			corrected = true
			log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
		}
		if _, exists := argsMap["work_dir"]; exists {
			delete(argsMap, "work_dir")
			corrected = true
			log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
		}
	case "edit":
		// OpenCode edit uses old_string/new_string; Codex may use other names.
		// Parameter-name mapping logic can be extended here.
		if _, exists := argsMap["file_path"]; !exists {
			if path, exists := argsMap["path"]; exists {
				argsMap["file_path"] = path
				delete(argsMap, "path")
				corrected = true
				log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
			}
		}
	}
	// If parameters were corrected, re-serialize in the original form.
	if corrected {
		if _, wasString := arguments.(string); wasString {
			// Originally a string: marshal the map back to a JSON string.
			if newArgsJSON, err := json.Marshal(argsMap); err == nil {
				functionCall["arguments"] = string(newArgsJSON)
			}
		} else {
			// Originally a map: assign the (mutated) map directly.
			functionCall["arguments"] = argsMap
		}
	}
	return corrected
}
// recordCorrection 记录一次工具名称修正
func
(
c
*
CodexToolCorrector
)
recordCorrection
(
from
,
to
string
)
{
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
c
.
stats
.
TotalCorrected
++
key
:=
fmt
.
Sprintf
(
"%s->%s"
,
from
,
to
)
c
.
stats
.
CorrectionsByTool
[
key
]
++
log
.
Printf
(
"[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)"
,
from
,
to
,
c
.
stats
.
TotalCorrected
)
}
// GetStats 获取工具修正统计信息
func
(
c
*
CodexToolCorrector
)
GetStats
()
ToolCorrectionStats
{
c
.
mu
.
RLock
()
defer
c
.
mu
.
RUnlock
()
// 返回副本以避免并发问题
statsCopy
:=
ToolCorrectionStats
{
TotalCorrected
:
c
.
stats
.
TotalCorrected
,
CorrectionsByTool
:
make
(
map
[
string
]
int
,
len
(
c
.
stats
.
CorrectionsByTool
)),
}
for
k
,
v
:=
range
c
.
stats
.
CorrectionsByTool
{
statsCopy
.
CorrectionsByTool
[
k
]
=
v
}
return
statsCopy
}
// ResetStats 重置统计信息
func
(
c
*
CodexToolCorrector
)
ResetStats
()
{
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
c
.
stats
.
TotalCorrected
=
0
c
.
stats
.
CorrectionsByTool
=
make
(
map
[
string
]
int
)
}
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
func
CorrectToolName
(
name
string
)
(
string
,
bool
)
{
if
correctName
,
found
:=
codexToolNameMapping
[
name
];
found
{
return
correctName
,
true
}
return
name
,
false
}
// GetToolNameMapping 获取工具名称映射表
func
GetToolNameMapping
()
map
[
string
]
string
{
// 返回副本以避免外部修改
mapping
:=
make
(
map
[
string
]
string
,
len
(
codexToolNameMapping
))
for
k
,
v
:=
range
codexToolNameMapping
{
mapping
[
k
]
=
v
}
return
mapping
}
backend/internal/service/openai_tool_corrector_test.go
0 → 100644
View file @
6901b64f
package
service
import
(
"encoding/json"
"testing"
)
// TestCorrectToolCallsInSSEData drives the SSE payload corrector through a
// table of JSON chunks covering every place a tool call may appear
// (tool_calls, function_call, delta.tool_calls, choices[].message.tool_calls),
// plus degenerate inputs (empty, whitespace, invalid JSON) that must pass
// through unchanged.
func TestCorrectToolCallsInSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()
	tests := []struct {
		name            string
		input           string
		expectCorrected bool
		// checkFunc, when set, inspects the corrected JSON in detail.
		checkFunc func(t *testing.T, result string)
	}{
		{
			name:            "empty string",
			input:           "",
			expectCorrected: false,
		},
		{
			name:            "newline only",
			input:           "\n",
			expectCorrected: false,
		},
		{
			name:            "invalid json",
			input:           "not a json",
			expectCorrected: false,
		},
		{
			// apply_patch is a Codex alias that must become "edit".
			name:            "correct apply_patch in tool_calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in result")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
				}
			},
		},
		{
			// Legacy single function_call object (non-array shape).
			name:            "correct update_plan in function_call",
			input:           `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				functionCall, ok := payload["function_call"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function_call format")
				}
				if functionCall["name"] != "todowrite" {
					t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
				}
			},
		},
		{
			// Streaming delta chunk shape.
			name:            "correct search_files in delta.tool_calls",
			input:           `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				delta, ok := payload["delta"].(map[string]any)
				if !ok {
					t.Fatal("Invalid delta format")
				}
				toolCalls, ok := delta["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in delta")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "grep" {
					t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
				}
			},
		},
		{
			// Full (non-streaming) completion shape.
			name:            "correct list_files in choices.message.tool_calls",
			input:           `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				choices, ok := payload["choices"].([]any)
				if !ok || len(choices) == 0 {
					t.Fatal("No choices found in result")
				}
				choice, ok := choices[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid choice format")
				}
				message, ok := choice["message"].(map[string]any)
				if !ok {
					t.Fatal("Invalid message format")
				}
				toolCalls, ok := message["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in message")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "glob" {
					t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
				}
			},
		},
		{
			// "read" is already canonical — payload must be untouched.
			name:            "no correction needed",
			input:           `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
			expectCorrected: false,
		},
		{
			// Every entry of the array must be corrected independently.
			name:            "correct multiple tool calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) < 2 {
					t.Fatal("Expected at least 2 tool_calls")
				}
				toolCall1, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid first tool_call format")
				}
				func1, ok := toolCall1["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid first function format")
				}
				if func1["name"] != "edit" {
					t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
				}
				toolCall2, ok := toolCalls[1].(map[string]any)
				if !ok {
					t.Fatal("Invalid second tool_call format")
				}
				func2, ok := toolCall2["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid second function format")
				}
				if func2["name"] != "read" {
					t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
				}
			},
		},
		{
			// camelCase aliases map the same way as snake_case ones.
			name:            "camelCase format - applyPatch",
			input:           `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in result")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
			if corrected != tt.expectCorrected {
				t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
			}
			// Uncorrected payloads must be returned byte-for-byte.
			if !corrected && result != tt.input {
				t.Errorf("Expected unchanged result when not corrected")
			}
			if tt.checkFunc != nil {
				tt.checkFunc(t, result)
			}
		})
	}
}
func
TestCorrectToolName
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
input
string
expected
string
corrected
bool
}{
{
"apply_patch"
,
"edit"
,
true
},
{
"applyPatch"
,
"edit"
,
true
},
{
"update_plan"
,
"todowrite"
,
true
},
{
"updatePlan"
,
"todowrite"
,
true
},
{
"read_plan"
,
"todoread"
,
true
},
{
"readPlan"
,
"todoread"
,
true
},
{
"search_files"
,
"grep"
,
true
},
{
"searchFiles"
,
"grep"
,
true
},
{
"list_files"
,
"glob"
,
true
},
{
"listFiles"
,
"glob"
,
true
},
{
"read_file"
,
"read"
,
true
},
{
"readFile"
,
"read"
,
true
},
{
"write_file"
,
"write"
,
true
},
{
"writeFile"
,
"write"
,
true
},
{
"execute_bash"
,
"bash"
,
true
},
{
"executeBash"
,
"bash"
,
true
},
{
"exec_bash"
,
"bash"
,
true
},
{
"execBash"
,
"bash"
,
true
},
{
"unknown_tool"
,
"unknown_tool"
,
false
},
{
"read"
,
"read"
,
false
},
{
"edit"
,
"edit"
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
input
,
func
(
t
*
testing
.
T
)
{
result
,
corrected
:=
CorrectToolName
(
tt
.
input
)
if
corrected
!=
tt
.
corrected
{
t
.
Errorf
(
"Expected corrected=%v, got %v"
,
tt
.
corrected
,
corrected
)
}
if
result
!=
tt
.
expected
{
t
.
Errorf
(
"Expected '%s', got '%s'"
,
tt
.
expected
,
result
)
}
})
}
}
func
TestGetToolNameMapping
(
t
*
testing
.
T
)
{
mapping
:=
GetToolNameMapping
()
expectedMappings
:=
map
[
string
]
string
{
"apply_patch"
:
"edit"
,
"update_plan"
:
"todowrite"
,
"read_plan"
:
"todoread"
,
"search_files"
:
"grep"
,
"list_files"
:
"glob"
,
}
for
from
,
to
:=
range
expectedMappings
{
if
mapping
[
from
]
!=
to
{
t
.
Errorf
(
"Expected mapping[%s] = %s, got %s"
,
from
,
to
,
mapping
[
from
])
}
}
mapping
[
"test_tool"
]
=
"test_value"
newMapping
:=
GetToolNameMapping
()
if
_
,
exists
:=
newMapping
[
"test_tool"
];
exists
{
t
.
Error
(
"Modifications to returned mapping should not affect original"
)
}
}
// TestCorrectorStats verifies the statistics lifecycle: a new corrector
// starts at zero, corrections increment both the total and the per-mapping
// counters, and ResetStats returns everything to zero.
func TestCorrectorStats(t *testing.T) {
	corrector := NewCodexToolCorrector()

	// Fresh corrector: all counters must be zero/empty.
	stats := corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
	}

	// Perform three corrections: apply_patch twice, update_plan once.
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)

	// Counters must reflect the three corrections, broken down by mapping.
	stats = corrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
	}
	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}
	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}

	// Reset must clear both the total and the per-mapping map.
	corrector.ResetStats()
	stats = corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
	}
}
// TestComplexSSEData feeds a realistic multi-field streaming chunk (id,
// object, model, choices[0].delta.tool_calls) through the corrector and
// verifies the alias "apply_patch" is rewritten to "edit" while the
// surrounding structure survives the round-trip.
func TestComplexSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()
	input := `{
		"id": "chatcmpl-123",
		"object": "chat.completion.chunk",
		"created": 1234567890,
		"model": "gpt-5.1-codex",
		"choices": [
			{
				"index": 0,
				"delta": {
					"tool_calls": [
						{
							"index": 0,
							"function": {
								"name": "apply_patch",
								"arguments": "{\"file\":\"test.go\"}"
							}
						}
					]
				},
				"finish_reason": null
			}
		]
	}`
	result, corrected := corrector.CorrectToolCallsInSSEData(input)
	if !corrected {
		t.Error("Expected data to be corrected")
	}

	// Walk choices[0].delta.tool_calls[0].function and check the name.
	var payload map[string]any
	if err := json.Unmarshal([]byte(result), &payload); err != nil {
		t.Fatalf("Failed to parse result: %v", err)
	}
	choices, ok := payload["choices"].([]any)
	if !ok || len(choices) == 0 {
		t.Fatal("No choices found in result")
	}
	choice, ok := choices[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid choice format")
	}
	delta, ok := choice["delta"].(map[string]any)
	if !ok {
		t.Fatal("Invalid delta format")
	}
	toolCalls, ok := delta["tool_calls"].([]any)
	if !ok || len(toolCalls) == 0 {
		t.Fatal("No tool_calls found in delta")
	}
	toolCall, ok := toolCalls[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid tool_call format")
	}
	function, ok := toolCall["function"].(map[string]any)
	if !ok {
		t.Fatal("Invalid function format")
	}
	if function["name"] != "edit" {
		t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
	}
}
// TestCorrectToolParameters verifies argument-level corrections: the
// corrector must drop the unsupported "workdir" parameter from bash calls
// and rename "path" to "file_path" in edit calls, while leaving the other
// arguments intact inside the JSON-encoded arguments string.
func TestCorrectToolParameters(t *testing.T) {
	corrector := NewCodexToolCorrector()
	tests := []struct {
		name     string
		input    string
		expected map[string]bool // key: argument name; value: true means it must exist after correction
	}{
		{
			name: "remove workdir from bash tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "bash",
						"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"command": true,
				"workdir": false,
			},
		},
		{
			name: "rename path to file_path in edit tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "apply_patch",
						"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"file_path":  true,
				"path":       false,
				"old_string": true,
				"new_string": true,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
			if !changed {
				t.Error("expected data to be corrected")
			}
			// Parse the corrected payload.
			var result map[string]any
			if err := json.Unmarshal([]byte(corrected), &result); err != nil {
				t.Fatalf("failed to parse corrected data: %v", err)
			}
			// Drill into tool_calls[0].function.arguments.
			toolCalls, ok := result["tool_calls"].([]any)
			if !ok || len(toolCalls) == 0 {
				t.Fatal("no tool_calls found in corrected data")
			}
			toolCall, ok := toolCalls[0].(map[string]any)
			if !ok {
				t.Fatal("invalid tool_call structure")
			}
			function, ok := toolCall["function"].(map[string]any)
			if !ok {
				t.Fatal("no function found in tool_call")
			}
			// Arguments were a JSON string on input, so they must still be
			// a (re-serialized) string after correction.
			argumentsStr, ok := function["arguments"].(string)
			if !ok {
				t.Fatal("arguments is not a string")
			}
			var args map[string]any
			if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
				t.Fatalf("failed to parse arguments: %v", err)
			}
			// Verify presence/absence of every expected parameter.
			for param, shouldExist := range tt.expected {
				_, exists := args[param]
				if shouldExist && !exists {
					t.Errorf("expected parameter %q to exist, but it doesn't", param)
				}
				if !shouldExist && exists {
					t.Errorf("expected parameter %q to not exist, but it does", param)
				}
			}
		})
	}
}
backend/internal/service/ops_aggregation_service.go
View file @
6901b64f
...
...
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
"sync"
...
...
@@ -235,11 +236,13 @@ func (s *OpsAggregationService) aggregateHourly() {
successAt
:=
finishedAt
hbCtx
,
hbCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
hbCancel
()
result
:=
truncateString
(
fmt
.
Sprintf
(
"window=%s..%s"
,
start
.
Format
(
time
.
RFC3339
),
end
.
Format
(
time
.
RFC3339
)),
2048
)
_
=
s
.
opsRepo
.
UpsertJobHeartbeat
(
hbCtx
,
&
OpsUpsertJobHeartbeatInput
{
JobName
:
opsAggHourlyJobName
,
LastRunAt
:
&
runAt
,
LastSuccessAt
:
&
successAt
,
LastDurationMs
:
&
dur
,
LastResult
:
&
result
,
})
}
...
...
@@ -331,11 +334,13 @@ func (s *OpsAggregationService) aggregateDaily() {
successAt
:=
finishedAt
hbCtx
,
hbCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
hbCancel
()
result
:=
truncateString
(
fmt
.
Sprintf
(
"window=%s..%s"
,
start
.
Format
(
time
.
RFC3339
),
end
.
Format
(
time
.
RFC3339
)),
2048
)
_
=
s
.
opsRepo
.
UpsertJobHeartbeat
(
hbCtx
,
&
OpsUpsertJobHeartbeatInput
{
JobName
:
opsAggDailyJobName
,
LastRunAt
:
&
runAt
,
LastSuccessAt
:
&
successAt
,
LastDurationMs
:
&
dur
,
LastResult
:
&
result
,
})
}
...
...
backend/internal/service/ops_alert_evaluator_service.go
View file @
6901b64f
...
...
@@ -190,6 +190,13 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
return
}
rulesTotal
:=
len
(
rules
)
rulesEnabled
:=
0
rulesEvaluated
:=
0
eventsCreated
:=
0
eventsResolved
:=
0
emailsSent
:=
0
now
:=
time
.
Now
()
.
UTC
()
safeEnd
:=
now
.
Truncate
(
time
.
Minute
)
if
safeEnd
.
IsZero
()
{
...
...
@@ -205,8 +212,9 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
if
rule
==
nil
||
!
rule
.
Enabled
||
rule
.
ID
<=
0
{
continue
}
rulesEnabled
++
scopePlatform
,
scopeGroupID
:=
parseOpsAlertRuleScope
(
rule
.
Filters
)
scopePlatform
,
scopeGroupID
,
scopeRegion
:=
parseOpsAlertRuleScope
(
rule
.
Filters
)
windowMinutes
:=
rule
.
WindowMinutes
if
windowMinutes
<=
0
{
...
...
@@ -220,6 +228,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
s
.
resetRuleState
(
rule
.
ID
,
now
)
continue
}
rulesEvaluated
++
breachedNow
:=
compareMetric
(
metricValue
,
rule
.
Operator
,
rule
.
Threshold
)
required
:=
requiredSustainedBreaches
(
rule
.
SustainedMinutes
,
interval
)
...
...
@@ -236,6 +245,17 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
// Scoped silencing: if a matching silence exists, skip creating a firing event.
if
s
.
opsService
!=
nil
{
platform
:=
strings
.
TrimSpace
(
scopePlatform
)
region
:=
scopeRegion
if
platform
!=
""
{
if
ok
,
err
:=
s
.
opsService
.
IsAlertSilenced
(
ctx
,
rule
.
ID
,
platform
,
scopeGroupID
,
region
,
now
);
err
==
nil
&&
ok
{
continue
}
}
}
latestEvent
,
err
:=
s
.
opsRepo
.
GetLatestAlertEvent
(
ctx
,
rule
.
ID
)
if
err
!=
nil
{
log
.
Printf
(
"[OpsAlertEvaluator] get latest event failed (rule=%d): %v"
,
rule
.
ID
,
err
)
...
...
@@ -267,8 +287,11 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
eventsCreated
++
if
created
!=
nil
&&
created
.
ID
>
0
{
s
.
maybeSendAlertEmail
(
ctx
,
runtimeCfg
,
rule
,
created
)
if
s
.
maybeSendAlertEmail
(
ctx
,
runtimeCfg
,
rule
,
created
)
{
emailsSent
++
}
}
continue
}
...
...
@@ -278,11 +301,14 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
resolvedAt
:=
now
if
err
:=
s
.
opsRepo
.
UpdateAlertEventStatus
(
ctx
,
activeEvent
.
ID
,
OpsAlertStatusResolved
,
&
resolvedAt
);
err
!=
nil
{
log
.
Printf
(
"[OpsAlertEvaluator] resolve event failed (event=%d): %v"
,
activeEvent
.
ID
,
err
)
}
else
{
eventsResolved
++
}
}
}
s
.
recordHeartbeatSuccess
(
runAt
,
time
.
Since
(
startedAt
))
result
:=
truncateString
(
fmt
.
Sprintf
(
"rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d"
,
rulesTotal
,
rulesEnabled
,
rulesEvaluated
,
eventsCreated
,
eventsResolved
,
emailsSent
),
2048
)
s
.
recordHeartbeatSuccess
(
runAt
,
time
.
Since
(
startedAt
),
result
)
}
func
(
s
*
OpsAlertEvaluatorService
)
pruneRuleStates
(
rules
[]
*
OpsAlertRule
)
{
...
...
@@ -359,9 +385,9 @@ func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int
return
required
}
func
parseOpsAlertRuleScope
(
filters
map
[
string
]
any
)
(
platform
string
,
groupID
*
int64
)
{
func
parseOpsAlertRuleScope
(
filters
map
[
string
]
any
)
(
platform
string
,
groupID
*
int64
,
region
*
string
)
{
if
filters
==
nil
{
return
""
,
nil
return
""
,
nil
,
nil
}
if
v
,
ok
:=
filters
[
"platform"
];
ok
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
...
...
@@ -392,7 +418,15 @@ func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *i
}
}
}
return
platform
,
groupID
if
v
,
ok
:=
filters
[
"region"
];
ok
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
vv
:=
strings
.
TrimSpace
(
s
)
if
vv
!=
""
{
region
=
&
vv
}
}
}
return
platform
,
groupID
,
region
}
func
(
s
*
OpsAlertEvaluatorService
)
computeRuleMetric
(
...
...
@@ -504,16 +538,6 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
return
0
,
false
}
return
overview
.
UpstreamErrorRate
*
100
,
true
case
"p95_latency_ms"
:
if
overview
.
Duration
.
P95
==
nil
{
return
0
,
false
}
return
float64
(
*
overview
.
Duration
.
P95
),
true
case
"p99_latency_ms"
:
if
overview
.
Duration
.
P99
==
nil
{
return
0
,
false
}
return
float64
(
*
overview
.
Duration
.
P99
),
true
default
:
return
0
,
false
}
...
...
@@ -576,32 +600,32 @@ func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes i
)
}
func
(
s
*
OpsAlertEvaluatorService
)
maybeSendAlertEmail
(
ctx
context
.
Context
,
runtimeCfg
*
OpsAlertRuntimeSettings
,
rule
*
OpsAlertRule
,
event
*
OpsAlertEvent
)
{
func
(
s
*
OpsAlertEvaluatorService
)
maybeSendAlertEmail
(
ctx
context
.
Context
,
runtimeCfg
*
OpsAlertRuntimeSettings
,
rule
*
OpsAlertRule
,
event
*
OpsAlertEvent
)
bool
{
if
s
==
nil
||
s
.
emailService
==
nil
||
s
.
opsService
==
nil
||
event
==
nil
||
rule
==
nil
{
return
return
false
}
if
event
.
EmailSent
{
return
return
false
}
if
!
rule
.
NotifyEmail
{
return
return
false
}
emailCfg
,
err
:=
s
.
opsService
.
GetEmailNotificationConfig
(
ctx
)
if
err
!=
nil
||
emailCfg
==
nil
||
!
emailCfg
.
Alert
.
Enabled
{
return
return
false
}
if
len
(
emailCfg
.
Alert
.
Recipients
)
==
0
{
return
return
false
}
if
!
shouldSendOpsAlertEmailByMinSeverity
(
strings
.
TrimSpace
(
emailCfg
.
Alert
.
MinSeverity
),
strings
.
TrimSpace
(
rule
.
Severity
))
{
return
return
false
}
if
runtimeCfg
!=
nil
&&
runtimeCfg
.
Silencing
.
Enabled
{
if
isOpsAlertSilenced
(
time
.
Now
()
.
UTC
(),
rule
,
event
,
runtimeCfg
.
Silencing
)
{
return
return
false
}
}
...
...
@@ -630,6 +654,7 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
if
anySent
{
_
=
s
.
opsRepo
.
UpdateAlertEventEmailSent
(
context
.
Background
(),
event
.
ID
,
true
)
}
return
anySent
}
func
buildOpsAlertEmailBody
(
rule
*
OpsAlertRule
,
event
*
OpsAlertEvent
)
string
{
...
...
@@ -797,7 +822,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
log
.
Printf
(
"[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)"
,
key
)
}
func
(
s
*
OpsAlertEvaluatorService
)
recordHeartbeatSuccess
(
runAt
time
.
Time
,
duration
time
.
Duration
)
{
func
(
s
*
OpsAlertEvaluatorService
)
recordHeartbeatSuccess
(
runAt
time
.
Time
,
duration
time
.
Duration
,
result
string
)
{
if
s
==
nil
||
s
.
opsRepo
==
nil
{
return
}
...
...
@@ -805,11 +830,17 @@ func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, durat
durMs
:=
duration
.
Milliseconds
()
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
msg
:=
strings
.
TrimSpace
(
result
)
if
msg
==
""
{
msg
=
"ok"
}
msg
=
truncateString
(
msg
,
2048
)
_
=
s
.
opsRepo
.
UpsertJobHeartbeat
(
ctx
,
&
OpsUpsertJobHeartbeatInput
{
JobName
:
opsAlertEvaluatorJobName
,
LastRunAt
:
&
runAt
,
LastSuccessAt
:
&
now
,
LastDurationMs
:
&
durMs
,
LastResult
:
&
msg
,
})
}
...
...
backend/internal/service/ops_alert_models.go
View file @
6901b64f
...
...
@@ -8,8 +8,9 @@ import "time"
// with the existing ops dashboard frontend (backup style).
const
(
OpsAlertStatusFiring
=
"firing"
OpsAlertStatusResolved
=
"resolved"
OpsAlertStatusFiring
=
"firing"
OpsAlertStatusResolved
=
"resolved"
OpsAlertStatusManualResolved
=
"manual_resolved"
)
type
OpsAlertRule
struct
{
...
...
@@ -58,12 +59,32 @@ type OpsAlertEvent struct {
CreatedAt
time
.
Time
`json:"created_at"`
}
// OpsAlertSilence is a scoped suppression record for a single alert rule:
// while a matching silence is in effect, the evaluator skips creating new
// firing events for that rule/platform (and optional group/region) scope.
type OpsAlertSilence struct {
	ID        int64      `json:"id"`
	RuleID    int64      `json:"rule_id"`              // alert rule this silence applies to
	Platform  string     `json:"platform"`             // platform scope (required by CreateAlertSilence)
	GroupID   *int64     `json:"group_id,omitempty"`   // optional group scope; nil = all groups
	Region    *string    `json:"region,omitempty"`     // optional region scope; nil = all regions
	Until     time.Time  `json:"until"`                // when the silence stops applying
	Reason    string     `json:"reason"`
	CreatedBy *int64     `json:"created_by,omitempty"` // creating user id, if recorded
	CreatedAt time.Time  `json:"created_at"`
}
type
OpsAlertEventFilter
struct
{
Limit
int
// Cursor pagination (descending by fired_at, then id).
BeforeFiredAt
*
time
.
Time
BeforeID
*
int64
// Optional filters.
Status
string
Severity
string
Status
string
Severity
string
EmailSent
*
bool
StartTime
*
time
.
Time
EndTime
*
time
.
Time
...
...
backend/internal/service/ops_alerts.go
View file @
6901b64f
...
...
@@ -88,6 +88,29 @@ func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventF
return
s
.
opsRepo
.
ListAlertEvents
(
ctx
,
filter
)
}
// GetAlertEventByID loads a single alert event by its primary key.
// It returns a typed NotFound error when the event does not exist and a
// ServiceUnavailable error when the ops repository is not wired up.
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
	// All ops reads require monitoring to be enabled.
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if eventID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
	}
	ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
	if err != nil {
		// Translate sql.ErrNoRows into an API-level NotFound.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
		}
		return nil, err
	}
	// Defensive: treat a (nil, nil) repo result as not found as well.
	if ev == nil {
		return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
	}
	return ev, nil
}
func
(
s
*
OpsService
)
GetActiveAlertEvent
(
ctx
context
.
Context
,
ruleID
int64
)
(
*
OpsAlertEvent
,
error
)
{
if
err
:=
s
.
RequireMonitoringEnabled
(
ctx
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -101,6 +124,49 @@ func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*Op
return
s
.
opsRepo
.
GetActiveAlertEvent
(
ctx
,
ruleID
)
}
// CreateAlertSilence validates and persists a new alert silence.
// Validation requires a positive rule id, a non-blank platform, and a
// non-zero expiry; GroupID and Region remain optional scope narrowing.
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
	// All ops writes require monitoring to be enabled.
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if input == nil {
		return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
	}
	if input.RuleID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
	}
	if strings.TrimSpace(input.Platform) == "" {
		return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
	}
	if input.Until.IsZero() {
		return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
	}
	created, err := s.opsRepo.CreateAlertSilence(ctx, input)
	if err != nil {
		return nil, err
	}
	return created, nil
}
// IsAlertSilenced reports whether a matching silence is active at `now`
// for the given rule/platform scope (group and region optionally narrow
// the match). A blank platform never matches and returns (false, nil)
// without touching the repository.
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return false, err
	}
	if s.opsRepo == nil {
		return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if ruleID <= 0 {
		return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
	}
	// No platform scope means the silence cannot apply; not an error.
	if strings.TrimSpace(platform) == "" {
		return false, nil
	}
	return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
}
func
(
s
*
OpsService
)
GetLatestAlertEvent
(
ctx
context
.
Context
,
ruleID
int64
)
(
*
OpsAlertEvent
,
error
)
{
if
err
:=
s
.
RequireMonitoringEnabled
(
ctx
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -142,7 +208,11 @@ func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64,
if
eventID
<=
0
{
return
infraerrors
.
BadRequest
(
"INVALID_EVENT_ID"
,
"invalid event id"
)
}
if
strings
.
TrimSpace
(
status
)
==
""
{
status
=
strings
.
TrimSpace
(
status
)
if
status
==
""
{
return
infraerrors
.
BadRequest
(
"INVALID_STATUS"
,
"invalid status"
)
}
if
status
!=
OpsAlertStatusResolved
&&
status
!=
OpsAlertStatusManualResolved
{
return
infraerrors
.
BadRequest
(
"INVALID_STATUS"
,
"invalid status"
)
}
return
s
.
opsRepo
.
UpdateAlertEventStatus
(
ctx
,
eventID
,
status
,
resolvedAt
)
...
...
backend/internal/service/ops_cleanup_service.go
View file @
6901b64f
...
...
@@ -149,7 +149,7 @@ func (s *OpsCleanupService) runScheduled() {
log
.
Printf
(
"[OpsCleanup] cleanup failed: %v"
,
err
)
return
}
s
.
recordHeartbeatSuccess
(
runAt
,
time
.
Since
(
startedAt
))
s
.
recordHeartbeatSuccess
(
runAt
,
time
.
Since
(
startedAt
)
,
counts
)
log
.
Printf
(
"[OpsCleanup] cleanup complete: %s"
,
counts
)
}
...
...
@@ -330,12 +330,13 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b
return
release
,
true
}
func
(
s
*
OpsCleanupService
)
recordHeartbeatSuccess
(
runAt
time
.
Time
,
duration
time
.
Duration
)
{
func
(
s
*
OpsCleanupService
)
recordHeartbeatSuccess
(
runAt
time
.
Time
,
duration
time
.
Duration
,
counts
opsCleanupDeletedCounts
)
{
if
s
==
nil
||
s
.
opsRepo
==
nil
{
return
}
now
:=
time
.
Now
()
.
UTC
()
durMs
:=
duration
.
Milliseconds
()
result
:=
truncateString
(
counts
.
String
(),
2048
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
_
=
s
.
opsRepo
.
UpsertJobHeartbeat
(
ctx
,
&
OpsUpsertJobHeartbeatInput
{
...
...
@@ -343,6 +344,7 @@ func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration tim
LastRunAt
:
&
runAt
,
LastSuccessAt
:
&
now
,
LastDurationMs
:
&
durMs
,
LastResult
:
&
result
,
})
}
...
...
backend/internal/service/ops_health_score.go
View file @
6901b64f
...
...
@@ -32,49 +32,38 @@ func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview)
}
// computeBusinessHealth calculates business health score (0-100)
// Components:
SLA (50%) +
Error Rate (
3
0%) +
Latency
(
2
0%)
// Components: Error Rate (
5
0%) +
TTFT
(
5
0%)
func
computeBusinessHealth
(
overview
*
OpsDashboardOverview
)
float64
{
// SLA score: 99.5% → 100, 95% → 0 (linear)
slaScore
:=
100.0
slaPct
:=
clampFloat64
(
overview
.
SLA
*
100
,
0
,
100
)
if
slaPct
<
99.5
{
if
slaPct
>=
95
{
slaScore
=
(
slaPct
-
95
)
/
4.5
*
100
}
else
{
slaScore
=
0
}
}
// Error rate score: 0.5% → 100, 5% → 0 (linear)
// Error rate score: 1% → 100, 10% → 0 (linear)
// Combines request errors and upstream errors
errorScore
:=
100.0
errorPct
:=
clampFloat64
(
overview
.
ErrorRate
*
100
,
0
,
100
)
upstreamPct
:=
clampFloat64
(
overview
.
UpstreamErrorRate
*
100
,
0
,
100
)
combinedErrorPct
:=
math
.
Max
(
errorPct
,
upstreamPct
)
// Use worst case
if
combinedErrorPct
>
0.5
{
if
combinedErrorPct
<=
5
{
errorScore
=
(
5
-
combinedErrorPct
)
/
4.5
*
100
if
combinedErrorPct
>
1.0
{
if
combinedErrorPct
<=
10.0
{
errorScore
=
(
10.0
-
combinedErrorPct
)
/
9.0
*
100
}
else
{
errorScore
=
0
}
}
//
Latency
score: 1s → 100,
10
s → 0 (linear)
//
Uses P99 of duration (TTFT is less critical for overall health)
latency
Score
:=
100.0
if
overview
.
Duration
.
P99
!=
nil
{
p99
:=
float64
(
*
overview
.
Duration
.
P99
)
//
TTFT
score: 1s → 100,
3
s → 0 (linear)
//
Time to first token is critical for user experience
ttft
Score
:=
100.0
if
overview
.
TTFT
.
P99
!=
nil
{
p99
:=
float64
(
*
overview
.
TTFT
.
P99
)
if
p99
>
1000
{
if
p99
<=
10
000
{
latency
Score
=
(
10
000
-
p99
)
/
9
000
*
100
if
p99
<=
3
000
{
ttft
Score
=
(
3
000
-
p99
)
/
2
000
*
100
}
else
{
latency
Score
=
0
ttft
Score
=
0
}
}
}
// Weighted combination
return
slaScore
*
0.5
+
errorScore
*
0.
3
+
latency
Score
*
0.
2
// Weighted combination
: 50% error rate + 50% TTFT
return
errorScore
*
0.
5
+
ttft
Score
*
0.
5
}
// computeInfraHealth calculates infrastructure health score (0-100)
...
...
backend/internal/service/ops_health_score_test.go
View file @
6901b64f
...
...
@@ -127,8 +127,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent
:
float64Ptr
(
75
),
},
},
wantMin
:
6
0
,
wantMax
:
85
,
wantMin
:
9
6
,
wantMax
:
97
,
},
{
name
:
"DB failure"
,
...
...
@@ -203,8 +203,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent
:
float64Ptr
(
30
),
},
},
wantMin
:
25
,
wantMax
:
5
0
,
wantMin
:
84
,
wantMax
:
8
5
,
},
{
name
:
"combined failures - business healthy + infra degraded"
,
...
...
@@ -277,30 +277,41 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate
:
0
,
Duration
:
OpsPercentiles
{
P99
:
intPtr
(
500
)},
},
wantMin
:
5
0
,
wantMax
:
6
0
,
wantMin
:
10
0
,
wantMax
:
10
0
,
},
{
name
:
"error rate boundary
0.5
%"
,
name
:
"error rate boundary
1
%"
,
overview
:
&
OpsDashboardOverview
{
SLA
:
0.99
5
,
ErrorRate
:
0.0
05
,
SLA
:
0.99
,
ErrorRate
:
0.0
1
,
UpstreamErrorRate
:
0
,
Duration
:
OpsPercentiles
{
P99
:
intPtr
(
500
)},
},
wantMin
:
95
,
wantMin
:
100
,
wantMax
:
100
,
},
{
name
:
"
latency boundary 1000ms
"
,
name
:
"
error rate 5%
"
,
overview
:
&
OpsDashboardOverview
{
SLA
:
0.995
,
SLA
:
0.95
,
ErrorRate
:
0.05
,
UpstreamErrorRate
:
0
,
Duration
:
OpsPercentiles
{
P99
:
intPtr
(
500
)},
},
wantMin
:
77
,
wantMax
:
78
,
},
{
name
:
"TTFT boundary 2s"
,
overview
:
&
OpsDashboardOverview
{
SLA
:
0.99
,
ErrorRate
:
0
,
UpstreamErrorRate
:
0
,
Duration
:
OpsPercentiles
{
P99
:
intPtr
(
1
000
)},
TTFT
:
OpsPercentiles
{
P99
:
intPtr
(
2
000
)},
},
wantMin
:
9
5
,
wantMax
:
100
,
wantMin
:
7
5
,
wantMax
:
75
,
},
{
name
:
"upstream error dominates"
,
...
...
@@ -310,7 +321,7 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate
:
0.03
,
Duration
:
OpsPercentiles
{
P99
:
intPtr
(
500
)},
},
wantMin
:
75
,
wantMin
:
88
,
wantMax
:
90
,
},
}
...
...
backend/internal/service/ops_models.go
View file @
6901b64f
...
...
@@ -6,24 +6,43 @@ type OpsErrorLog struct {
ID
int64
`json:"id"`
CreatedAt
time
.
Time
`json:"created_at"`
Phase
string
`json:"phase"`
Type
string
`json:"type"`
// Standardized classification
// - phase: request|auth|routing|upstream|network|internal
// - owner: client|provider|platform
// - source: client_request|upstream_http|gateway
Phase
string
`json:"phase"`
Type
string
`json:"type"`
Owner
string
`json:"error_owner"`
Source
string
`json:"error_source"`
Severity
string
`json:"severity"`
StatusCode
int
`json:"status_code"`
Platform
string
`json:"platform"`
Model
string
`json:"model"`
LatencyMs
*
int
`json:"latency_ms"`
IsRetryable
bool
`json:"is_retryable"`
RetryCount
int
`json:"retry_count"`
Resolved
bool
`json:"resolved"`
ResolvedAt
*
time
.
Time
`json:"resolved_at"`
ResolvedByUserID
*
int64
`json:"resolved_by_user_id"`
ResolvedByUserName
string
`json:"resolved_by_user_name"`
ResolvedRetryID
*
int64
`json:"resolved_retry_id"`
ResolvedStatusRaw
string
`json:"-"`
ClientRequestID
string
`json:"client_request_id"`
RequestID
string
`json:"request_id"`
Message
string
`json:"message"`
UserID
*
int64
`json:"user_id"`
APIKeyID
*
int64
`json:"api_key_id"`
AccountID
*
int64
`json:"account_id"`
GroupID
*
int64
`json:"group_id"`
UserID
*
int64
`json:"user_id"`
UserEmail
string
`json:"user_email"`
APIKeyID
*
int64
`json:"api_key_id"`
AccountID
*
int64
`json:"account_id"`
AccountName
string
`json:"account_name"`
GroupID
*
int64
`json:"group_id"`
GroupName
string
`json:"group_name"`
ClientIP
*
string
`json:"client_ip"`
RequestPath
string
`json:"request_path"`
...
...
@@ -67,9 +86,24 @@ type OpsErrorLogFilter struct {
GroupID
*
int64
AccountID
*
int64
StatusCodes
[]
int
Phase
string
Query
string
StatusCodes
[]
int
StatusCodesOther
bool
Phase
string
Owner
string
Source
string
Resolved
*
bool
Query
string
UserQuery
string
// Search by user email
// Optional correlation keys for exact matching.
RequestID
string
ClientRequestID
string
// View controls error categorization for list endpoints.
// - errors: show actionable errors (exclude business-limited / 429 / 529)
// - excluded: only show excluded errors
// - all: show everything
View
string
Page
int
PageSize
int
...
...
@@ -90,12 +124,23 @@ type OpsRetryAttempt struct {
SourceErrorID
int64
`json:"source_error_id"`
Mode
string
`json:"mode"`
PinnedAccountID
*
int64
`json:"pinned_account_id"`
PinnedAccountName
string
`json:"pinned_account_name"`
Status
string
`json:"status"`
StartedAt
*
time
.
Time
`json:"started_at"`
FinishedAt
*
time
.
Time
`json:"finished_at"`
DurationMs
*
int64
`json:"duration_ms"`
// Persisted execution results (best-effort)
Success
*
bool
`json:"success"`
HTTPStatusCode
*
int
`json:"http_status_code"`
UpstreamRequestID
*
string
`json:"upstream_request_id"`
UsedAccountID
*
int64
`json:"used_account_id"`
UsedAccountName
string
`json:"used_account_name"`
ResponsePreview
*
string
`json:"response_preview"`
ResponseTruncated
*
bool
`json:"response_truncated"`
// Optional correlation
ResultRequestID
*
string
`json:"result_request_id"`
ResultErrorID
*
int64
`json:"result_error_id"`
...
...
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment