Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
b9b4db3d
Commit
b9b4db3d
authored
Jan 17, 2026
by
song
Browse files
Merge upstream/main
parents
5a6f60a9
dae0d532
Changes
237
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
237 of 237+
files are displayed.
Plain diff
Email patch
backend/internal/service/group.go
View file @
b9b4db3d
package
service
import
"time"
import
(
"strings"
"time"
)
type
Group
struct
{
ID
int64
...
...
@@ -10,6 +13,7 @@ type Group struct {
RateMultiplier
float64
IsExclusive
bool
Status
string
Hydrated
bool
// indicates the group was loaded from a trusted repository source
SubscriptionType
string
DailyLimitUSD
*
float64
...
...
@@ -26,6 +30,12 @@ type Group struct {
ClaudeCodeOnly
bool
FallbackGroupID
*
int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
// value: 优先账号 ID 列表
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
@@ -72,3 +82,58 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return
g
.
ImagePrice2K
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func
IsGroupContextValid
(
group
*
Group
)
bool
{
if
group
==
nil
{
return
false
}
if
group
.
ID
<=
0
{
return
false
}
if
!
group
.
Hydrated
{
return
false
}
if
group
.
Platform
==
""
||
group
.
Status
==
""
{
return
false
}
return
true
}
// GetRoutingAccountIDs 根据请求模型获取路由账号 ID 列表
// 返回匹配的优先账号 ID 列表,如果没有匹配规则则返回 nil
func
(
g
*
Group
)
GetRoutingAccountIDs
(
requestedModel
string
)
[]
int64
{
if
!
g
.
ModelRoutingEnabled
||
len
(
g
.
ModelRouting
)
==
0
||
requestedModel
==
""
{
return
nil
}
// 1. 精确匹配优先
if
accountIDs
,
ok
:=
g
.
ModelRouting
[
requestedModel
];
ok
&&
len
(
accountIDs
)
>
0
{
return
accountIDs
}
// 2. 通配符匹配(前缀匹配)
for
pattern
,
accountIDs
:=
range
g
.
ModelRouting
{
if
matchModelPattern
(
pattern
,
requestedModel
)
&&
len
(
accountIDs
)
>
0
{
return
accountIDs
}
}
return
nil
}
// matchModelPattern reports whether model matches pattern. Only a trailing
// "*" wildcard is supported (e.g. "claude-opus-*" matches
// "claude-opus-4-20250514"); without a wildcard the comparison is exact.
func matchModelPattern(pattern, model string) bool {
	// A pattern ending in "*" is a prefix rule; anything else is literal.
	if prefix, isWildcard := strings.CutSuffix(pattern, "*"); isWildcard {
		return strings.HasPrefix(model, prefix)
	}
	return pattern == model
}
backend/internal/service/group_service.go
View file @
b9b4db3d
...
...
@@ -16,6 +16,7 @@ var (
type
GroupRepository
interface
{
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
...
...
@@ -50,12 +51,14 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务
type
GroupService
struct
{
groupRepo
GroupRepository
authCacheInvalidator
APIKeyAuthCacheInvalidator
}
// NewGroupService 创建分组服务实例
func
NewGroupService
(
groupRepo
GroupRepository
)
*
GroupService
{
func
NewGroupService
(
groupRepo
GroupRepository
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
)
*
GroupService
{
return
&
GroupService
{
groupRepo
:
groupRepo
,
authCacheInvalidator
:
authCacheInvalidator
,
}
}
...
...
@@ -154,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update group: %w"
,
err
)
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
return
group
,
nil
}
...
...
@@ -166,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
return
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
if
err
:=
s
.
groupRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete group: %w"
,
err
)
}
...
...
backend/internal/service/model_rate_limit.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"strings"
"time"
)
const
modelRateLimitsKey
=
"model_rate_limits"
const
modelRateLimitScopeClaudeSonnet
=
"claude_sonnet"
func
resolveModelRateLimitScope
(
requestedModel
string
)
(
string
,
bool
)
{
model
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
requestedModel
))
if
model
==
""
{
return
""
,
false
}
model
=
strings
.
TrimPrefix
(
model
,
"models/"
)
if
strings
.
Contains
(
model
,
"sonnet"
)
{
return
modelRateLimitScopeClaudeSonnet
,
true
}
return
""
,
false
}
func
(
a
*
Account
)
isModelRateLimited
(
requestedModel
string
)
bool
{
scope
,
ok
:=
resolveModelRateLimitScope
(
requestedModel
)
if
!
ok
{
return
false
}
resetAt
:=
a
.
modelRateLimitResetAt
(
scope
)
if
resetAt
==
nil
{
return
false
}
return
time
.
Now
()
.
Before
(
*
resetAt
)
}
// modelRateLimitResetAt extracts the rate-limit reset time for the given
// scope from the account's Extra metadata, or returns nil when the data is
// absent or malformed.
//
// Expected Extra layout (values come from JSON, hence the map[string]any
// assertions):
//
//	Extra["model_rate_limits"][scope]["rate_limit_reset_at"] = RFC3339 string
//
// Every failed type assertion or parse error degrades to nil (not limited)
// rather than an error, so stale/garbled metadata can never block scheduling.
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
	// Nil receiver tolerated deliberately: callers may hold a nil account.
	if a == nil || a.Extra == nil || scope == "" {
		return nil
	}
	rawLimits, ok := a.Extra[modelRateLimitsKey].(map[string]any)
	if !ok {
		return nil
	}
	rawLimit, ok := rawLimits[scope].(map[string]any)
	if !ok {
		return nil
	}
	resetAtRaw, ok := rawLimit["rate_limit_reset_at"].(string)
	if !ok || strings.TrimSpace(resetAtRaw) == "" {
		return nil
	}
	resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
	if err != nil {
		return nil
	}
	return &resetAt
}
backend/internal/service/openai_codex_transform.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
_
"embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
	// Upstream source of the opencode Codex system-prompt header.
	opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
	// How long a cached header is trusted before re-checking upstream.
	codexCacheTTL = 15 * time.Minute
)

// Built-in Codex CLI instructions, embedded at compile time as a fallback
// when the opencode header cannot be fetched.
//
//go:embed prompts/codex_cli_instructions.md
var codexCLIInstructions string

// codexModelMap maps client-visible model aliases (including reasoning-effort
// suffixed variants like "-low"/"-high") to the canonical upstream model name.
var codexModelMap = map[string]string{
	"gpt-5.1-codex":             "gpt-5.1-codex",
	"gpt-5.1-codex-low":         "gpt-5.1-codex",
	"gpt-5.1-codex-medium":      "gpt-5.1-codex",
	"gpt-5.1-codex-high":        "gpt-5.1-codex",
	"gpt-5.1-codex-max":         "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-low":     "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-medium":  "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-high":    "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-xhigh":   "gpt-5.1-codex-max",
	"gpt-5.2":                   "gpt-5.2",
	"gpt-5.2-none":              "gpt-5.2",
	"gpt-5.2-low":               "gpt-5.2",
	"gpt-5.2-medium":            "gpt-5.2",
	"gpt-5.2-high":              "gpt-5.2",
	"gpt-5.2-xhigh":             "gpt-5.2",
	"gpt-5.2-codex":             "gpt-5.2-codex",
	"gpt-5.2-codex-low":         "gpt-5.2-codex",
	"gpt-5.2-codex-medium":      "gpt-5.2-codex",
	"gpt-5.2-codex-high":        "gpt-5.2-codex",
	"gpt-5.2-codex-xhigh":       "gpt-5.2-codex",
	"gpt-5.1-codex-mini":        "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-high":   "gpt-5.1-codex-mini",
	"gpt-5.1":                   "gpt-5.1",
	"gpt-5.1-none":              "gpt-5.1",
	"gpt-5.1-low":               "gpt-5.1",
	"gpt-5.1-medium":            "gpt-5.1",
	"gpt-5.1-high":              "gpt-5.1",
	"gpt-5.1-chat-latest":       "gpt-5.1",
	// Legacy gpt-5 aliases are upgraded to their 5.1 equivalents.
	"gpt-5-codex":             "gpt-5.1-codex",
	"codex-mini-latest":       "gpt-5.1-codex-mini",
	"gpt-5-codex-mini":        "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-high":   "gpt-5.1-codex-mini",
	"gpt-5":                   "gpt-5.1",
	"gpt-5-mini":              "gpt-5.1",
	"gpt-5-nano":              "gpt-5.1",
}
// codexTransformResult summarizes what applyCodexOAuthTransform changed.
type codexTransformResult struct {
	// Modified is true when the request body was mutated and must be
	// re-serialized before forwarding.
	Modified bool
	// NormalizedModel is the canonical model name after normalization
	// (empty when no normalization applied).
	NormalizedModel string
	// PromptCacheKey is the trimmed prompt_cache_key from the request, if any.
	PromptCacheKey string
}

// opencodeCacheMetadata is the on-disk sidecar for a cached opencode prompt,
// tracking the ETag and freshness timestamps used for conditional refetch.
type opencodeCacheMetadata struct {
	ETag      string `json:"etag"`
	LastFetch string `json:"lastFetch,omitempty"`
	// LastChecked is a Unix-millisecond timestamp of the last upstream check.
	LastChecked int64 `json:"lastChecked"`
}
// applyCodexOAuthTransform rewrites an OpenAI Responses-API request body so it
// is acceptable to the ChatGPT internal (OAuth) Codex endpoint: it normalizes
// the model name, forces store=false and stream=true, strips token limits,
// normalizes tool definitions, injects system instructions, and filters the
// input array. The body is mutated in place; the returned result records what
// changed.
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
	result := codexTransformResult{}
	// Tool-continuation requests influence the storage policy and the
	// input-filtering logic below.
	needsToolContinuation := NeedsToolContinuation(reqBody)
	model := ""
	if v, ok := reqBody["model"].(string); ok {
		model = v
	}
	normalizedModel := normalizeCodexModel(model)
	if normalizedModel != "" {
		if model != normalizedModel {
			reqBody["model"] = normalizedModel
			result.Modified = true
		}
		result.NormalizedModel = normalizedModel
	}
	// When OAuth goes through the ChatGPT internal API, store must be false;
	// an explicit true is overridden as well. Avoids the upstream error
	// "Store must be set to false".
	if v, ok := reqBody["store"].(bool); !ok || v {
		reqBody["store"] = false
		result.Modified = true
	}
	// The internal endpoint only supports streaming responses.
	if v, ok := reqBody["stream"].(bool); !ok || !v {
		reqBody["stream"] = true
		result.Modified = true
	}
	// Token-limit fields are not accepted by this endpoint.
	if _, ok := reqBody["max_output_tokens"]; ok {
		delete(reqBody, "max_output_tokens")
		result.Modified = true
	}
	if _, ok := reqBody["max_completion_tokens"]; ok {
		delete(reqBody, "max_completion_tokens")
		result.Modified = true
	}
	if normalizeCodexTools(reqBody) {
		result.Modified = true
	}
	if v, ok := reqBody["prompt_cache_key"].(string); ok {
		result.PromptCacheKey = strings.TrimSpace(v)
	}
	// Prefer the opencode header as the instruction block; fall back to the
	// embedded Codex CLI instructions only when nothing else is available.
	instructions := strings.TrimSpace(getOpenCodeCodexHeader())
	existingInstructions, _ := reqBody["instructions"].(string)
	existingInstructions = strings.TrimSpace(existingInstructions)
	if instructions != "" {
		if existingInstructions != instructions {
			reqBody["instructions"] = instructions
			result.Modified = true
		}
	} else if existingInstructions == "" {
		// No opencode instructions available: fall back to Codex CLI instructions.
		codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
		if codexInstructions != "" {
			reqBody["instructions"] = codexInstructions
			result.Modified = true
		}
	}
	// Continuation requests keep item_reference entries and ids so call_id
	// context is not lost.
	if input, ok := reqBody["input"].([]any); ok {
		input = filterCodexInput(input, needsToolContinuation)
		reqBody["input"] = input
		result.Modified = true
	}
	return result
}
func
normalizeCodexModel
(
model
string
)
string
{
if
model
==
""
{
return
"gpt-5.1"
}
modelID
:=
model
if
strings
.
Contains
(
modelID
,
"/"
)
{
parts
:=
strings
.
Split
(
modelID
,
"/"
)
modelID
=
parts
[
len
(
parts
)
-
1
]
}
if
mapped
:=
getNormalizedCodexModel
(
modelID
);
mapped
!=
""
{
return
mapped
}
normalized
:=
strings
.
ToLower
(
modelID
)
if
strings
.
Contains
(
normalized
,
"gpt-5.2-codex"
)
||
strings
.
Contains
(
normalized
,
"gpt 5.2 codex"
)
{
return
"gpt-5.2-codex"
}
if
strings
.
Contains
(
normalized
,
"gpt-5.2"
)
||
strings
.
Contains
(
normalized
,
"gpt 5.2"
)
{
return
"gpt-5.2"
}
if
strings
.
Contains
(
normalized
,
"gpt-5.1-codex-max"
)
||
strings
.
Contains
(
normalized
,
"gpt 5.1 codex max"
)
{
return
"gpt-5.1-codex-max"
}
if
strings
.
Contains
(
normalized
,
"gpt-5.1-codex-mini"
)
||
strings
.
Contains
(
normalized
,
"gpt 5.1 codex mini"
)
{
return
"gpt-5.1-codex-mini"
}
if
strings
.
Contains
(
normalized
,
"codex-mini-latest"
)
||
strings
.
Contains
(
normalized
,
"gpt-5-codex-mini"
)
||
strings
.
Contains
(
normalized
,
"gpt 5 codex mini"
)
{
return
"codex-mini-latest"
}
if
strings
.
Contains
(
normalized
,
"gpt-5.1-codex"
)
||
strings
.
Contains
(
normalized
,
"gpt 5.1 codex"
)
{
return
"gpt-5.1-codex"
}
if
strings
.
Contains
(
normalized
,
"gpt-5.1"
)
||
strings
.
Contains
(
normalized
,
"gpt 5.1"
)
{
return
"gpt-5.1"
}
if
strings
.
Contains
(
normalized
,
"codex"
)
{
return
"gpt-5.1-codex"
}
if
strings
.
Contains
(
normalized
,
"gpt-5"
)
||
strings
.
Contains
(
normalized
,
"gpt 5"
)
{
return
"gpt-5.1"
}
return
"gpt-5.1"
}
func
getNormalizedCodexModel
(
modelID
string
)
string
{
if
modelID
==
""
{
return
""
}
if
mapped
,
ok
:=
codexModelMap
[
modelID
];
ok
{
return
mapped
}
lower
:=
strings
.
ToLower
(
modelID
)
for
key
,
value
:=
range
codexModelMap
{
if
strings
.
ToLower
(
key
)
==
lower
{
return
value
}
}
return
""
}
// getOpenCodeCachedPrompt returns the prompt at url, served from an on-disk
// cache when fresh. Flow: read cached content + metadata; if checked within
// codexCacheTTL, return the cache; otherwise refetch with the stored ETag.
// A 304 keeps the cache; a 2xx with a body rewrites cache + metadata; any
// failure falls back to whatever cached content exists (possibly "").
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
	cacheDir := codexCachePath("")
	if cacheDir == "" {
		// No resolvable home directory: caching (and fetching) disabled.
		return ""
	}
	cacheFile := filepath.Join(cacheDir, cacheFileName)
	metaFile := filepath.Join(cacheDir, metaFileName)
	var cachedContent string
	if content, ok := readFile(cacheFile); ok {
		cachedContent = content
	}
	var meta opencodeCacheMetadata
	if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
		// Serve from cache while still inside the TTL window.
		if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
			return cachedContent
		}
	}
	// Conditional refetch; meta.ETag is zero-valued when metadata was absent.
	content, etag, status, err := fetchWithETag(url, meta.ETag)
	if err == nil && status == http.StatusNotModified && cachedContent != "" {
		return cachedContent
	}
	if err == nil && status >= 200 && status < 300 && content != "" {
		// Cache write failures are deliberately ignored: serving the fresh
		// content matters more than persisting it.
		_ = writeFile(cacheFile, content)
		meta = opencodeCacheMetadata{
			ETag:        etag,
			LastFetch:   time.Now().UTC().Format(time.RFC3339),
			LastChecked: time.Now().UnixMilli(),
		}
		_ = writeJSON(metaFile, meta)
		return content
	}
	// Fetch failed or returned nothing useful: degrade to stale cache.
	return cachedContent
}
func
getOpenCodeCodexHeader
()
string
{
// 优先从 opencode 仓库缓存获取指令。
opencodeInstructions
:=
getOpenCodeCachedPrompt
(
opencodeCodexHeaderURL
,
"opencode-codex-header.txt"
,
"opencode-codex-header-meta.json"
)
// 若 opencode 指令可用,直接返回。
if
opencodeInstructions
!=
""
{
return
opencodeInstructions
}
// 否则回退使用本地 Codex CLI 指令。
return
getCodexCLIInstructions
()
}
// getCodexCLIInstructions returns the embedded Codex CLI instruction text.
func getCodexCLIInstructions() string {
	return codexCLIInstructions
}
// GetOpenCodeInstructions exposes the resolved Codex instructions (opencode
// header with embedded CLI fallback) to other packages.
func GetOpenCodeInstructions() string {
	return getOpenCodeCodexHeader()
}
// GetCodexCLIInstructions returns the built-in Codex CLI instruction content.
func GetCodexCLIInstructions() string {
	return getCodexCLIInstructions()
}
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
func
ReplaceWithCodexInstructions
(
reqBody
map
[
string
]
any
)
bool
{
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
if
codexInstructions
==
""
{
return
false
}
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
if
strings
.
TrimSpace
(
existingInstructions
)
!=
codexInstructions
{
reqBody
[
"instructions"
]
=
codexInstructions
return
true
}
return
false
}
// IsInstructionError reports whether an upstream error message is related to
// instruction format / system prompt problems, using a case-insensitive
// keyword scan. An empty message is never an instruction error.
func IsInstructionError(errorMessage string) bool {
	if errorMessage == "" {
		return false
	}
	lowerMsg := strings.ToLower(errorMessage)
	// "instruction" also matches "instructions", so a separate plural entry
	// (present in the original list) is redundant and has been removed.
	instructionKeywords := []string{
		"instruction",
		"system prompt",
		"system message",
		"invalid prompt",
		"prompt format",
	}
	for _, keyword := range instructionKeywords {
		if strings.Contains(lowerMsg, keyword) {
			return true
		}
	}
	return false
}
// filterCodexInput filters item_reference entries and id/call_id fields as
// needed. When preserveReferences is true, references and ids are kept to
// satisfy tool-continuation requests that depend on prior context; when
// false, item_reference entries are dropped and ids stripped.
// The input slice's maps are never mutated: entries are copied on write.
func filterCodexInput(input []any, preserveReferences bool) []any {
	filtered := make([]any, 0, len(input))
	for _, item := range input {
		m, ok := item.(map[string]any)
		if !ok {
			// Non-map entries pass through untouched.
			filtered = append(filtered, item)
			continue
		}
		typ, _ := m["type"].(string)
		if typ == "item_reference" {
			if !preserveReferences {
				// Drop references entirely outside continuation mode.
				continue
			}
			// Keep the reference, but as a shallow copy so later mutation of
			// the result cannot leak back into the caller's input.
			newItem := make(map[string]any, len(m))
			for key, value := range m {
				newItem[key] = value
			}
			filtered = append(filtered, newItem)
			continue
		}
		newItem := m
		copied := false
		// Create a copy only when a field actually needs changing, avoiding
		// mutation of the original input.
		ensureCopy := func() {
			if copied {
				return
			}
			newItem = make(map[string]any, len(m))
			for key, value := range m {
				newItem[key] = value
			}
			copied = true
		}
		if isCodexToolCallItemType(typ) {
			// Tool-call items need a call_id; backfill it from "id" when the
			// call_id is missing or blank.
			if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
				if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
					ensureCopy()
					newItem["call_id"] = id
				}
			}
		}
		if !preserveReferences {
			// Non-continuation: strip "id" everywhere, and "call_id" from
			// non-tool-call items (tool calls keep theirs for pairing).
			ensureCopy()
			delete(newItem, "id")
			if !isCodexToolCallItemType(typ) {
				delete(newItem, "call_id")
			}
		}
		filtered = append(filtered, newItem)
	}
	return filtered
}
// isCodexToolCallItemType reports whether typ names a tool-call input item
// (e.g. "function_call" or "function_call_output").
func isCodexToolCallItemType(typ string) bool {
	switch {
	case typ == "":
		return false
	case strings.HasSuffix(typ, "_call"):
		return true
	case strings.HasSuffix(typ, "_call_output"):
		return true
	default:
		return false
	}
}
// normalizeCodexTools flattens Chat-Completions-style tool definitions
// ({"type":"function","function":{...}}) into the Responses-API shape by
// hoisting name/description/parameters/strict from the nested "function"
// object to the tool's top level when absent there. Existing top-level
// fields always win. It reports whether the body was modified.
//
// The original implementation repeated the same hoist stanza four times;
// the per-field logic now lives in two small helpers.
func normalizeCodexTools(reqBody map[string]any) bool {
	rawTools, ok := reqBody["tools"]
	if !ok || rawTools == nil {
		return false
	}
	tools, ok := rawTools.([]any)
	if !ok {
		return false
	}
	modified := false
	for idx, tool := range tools {
		toolMap, ok := tool.(map[string]any)
		if !ok {
			continue
		}
		toolType, _ := toolMap["type"].(string)
		if strings.TrimSpace(toolType) != "function" {
			continue
		}
		function, ok := toolMap["function"].(map[string]any)
		if !ok {
			continue
		}
		// String fields are hoisted only when non-blank; the untyped fields
		// (parameters, strict) are hoisted whenever present.
		if hoistToolStringField(toolMap, function, "name") {
			modified = true
		}
		if hoistToolStringField(toolMap, function, "description") {
			modified = true
		}
		if hoistToolField(toolMap, function, "parameters") {
			modified = true
		}
		if hoistToolField(toolMap, function, "strict") {
			modified = true
		}
		tools[idx] = toolMap
	}
	if modified {
		reqBody["tools"] = tools
	}
	return modified
}

// hoistToolStringField copies function[key] to toolMap[key] when toolMap
// lacks the key and the nested value is a non-blank string.
func hoistToolStringField(toolMap, function map[string]any, key string) bool {
	if _, ok := toolMap[key]; ok {
		return false
	}
	v, ok := function[key].(string)
	if !ok || strings.TrimSpace(v) == "" {
		return false
	}
	toolMap[key] = v
	return true
}

// hoistToolField copies function[key] to toolMap[key] (any type) when toolMap
// lacks the key and the nested map has it.
func hoistToolField(toolMap, function map[string]any, key string) bool {
	if _, ok := toolMap[key]; ok {
		return false
	}
	v, ok := function[key]
	if !ok {
		return false
	}
	toolMap[key] = v
	return true
}
func
codexCachePath
(
filename
string
)
string
{
home
,
err
:=
os
.
UserHomeDir
()
if
err
!=
nil
{
return
""
}
cacheDir
:=
filepath
.
Join
(
home
,
".opencode"
,
"cache"
)
if
filename
==
""
{
return
cacheDir
}
return
filepath
.
Join
(
cacheDir
,
filename
)
}
func
readFile
(
path
string
)
(
string
,
bool
)
{
if
path
==
""
{
return
""
,
false
}
data
,
err
:=
os
.
ReadFile
(
path
)
if
err
!=
nil
{
return
""
,
false
}
return
string
(
data
),
true
}
func
writeFile
(
path
,
content
string
)
error
{
if
path
==
""
{
return
fmt
.
Errorf
(
"empty cache path"
)
}
if
err
:=
os
.
MkdirAll
(
filepath
.
Dir
(
path
),
0
o755
);
err
!=
nil
{
return
err
}
return
os
.
WriteFile
(
path
,
[]
byte
(
content
),
0
o644
)
}
func
loadJSON
(
path
string
,
target
any
)
bool
{
data
,
err
:=
os
.
ReadFile
(
path
)
if
err
!=
nil
{
return
false
}
if
err
:=
json
.
Unmarshal
(
data
,
target
);
err
!=
nil
{
return
false
}
return
true
}
func
writeJSON
(
path
string
,
value
any
)
error
{
if
path
==
""
{
return
fmt
.
Errorf
(
"empty json path"
)
}
if
err
:=
os
.
MkdirAll
(
filepath
.
Dir
(
path
),
0
o755
);
err
!=
nil
{
return
err
}
data
,
err
:=
json
.
Marshal
(
value
)
if
err
!=
nil
{
return
err
}
return
os
.
WriteFile
(
path
,
data
,
0
o644
)
}
// fetchWithETag performs a conditional GET of url, sending If-None-Match when
// etag is non-empty. It returns (body, response etag, status code, error);
// on transport failure the status is 0.
//
// Fixes over the original: the request now uses a client with an explicit
// timeout instead of http.DefaultClient (which has none and could hang this
// code path forever), and the body read is capped at 2 MiB — these prompts
// are small text files, and an unbounded io.ReadAll of a hostile or broken
// upstream could balloon memory.
func fetchWithETag(url, etag string) (string, string, int, error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", "", 0, err
	}
	req.Header.Set("User-Agent", "sub2api-codex")
	if etag != "" {
		req.Header.Set("If-None-Match", etag)
	}
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", "", 0, err
	}
	defer func() { _ = resp.Body.Close() }()
	// 2 MiB cap mirrors the bounded body reads elsewhere in this service.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	if err != nil {
		return "", "", resp.StatusCode, err
	}
	return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}
backend/internal/service/openai_codex_transform_test.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestApplyCodexOAuthTransform_ToolContinuationPreservesInput verifies the
// continuation scenario: item_reference entries and ids are preserved, and
// store is no longer forced to true.
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.2",
		"input": []any{
			map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
			map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	// store was not explicitly set, so it defaults to false.
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 2)
	// Assert input[0] is a map before field checks so a type mismatch fails cleanly.
	first, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "item_reference", first["type"])
	require.Equal(t, "ref1", first["id"])
	// Assert input[1] is a map so the subsequent field assertion is safe.
	second, ok := input[1].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "o1", second["id"])
}
// TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved verifies that in a
// continuation scenario an explicit store=false is kept (no longer forced to
// true).
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": false,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
// TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse verifies that an
// explicit store=true is overridden to false (the OAuth internal API rejects
// store=true).
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": true,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
// TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs
// verifies the non-continuation scenario: store defaults to false when unset,
// and "id" fields are removed from input items.
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{
			map[string]any{"type": "text", "id": "t1", "text": "hi"},
		},
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	// Assert input[0] is a map so the id-removal check cannot panic.
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	_, hasID := item["id"]
	require.False(t, hasID)
}
// TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved verifies that
// with preserveReferences=false, item_reference entries are dropped and "id"
// fields are stripped from the remaining items.
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
	input := []any{
		map[string]any{"type": "item_reference", "id": "ref1"},
		map[string]any{"type": "text", "id": "t1", "text": "hi"},
	}
	filtered := filterCodexInput(input, false)
	require.Len(t, filtered, 1)
	// Assert filtered[0] is a map so the field checks are reliable.
	item, ok := filtered[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", item["type"])
	_, hasID := item["id"]
	require.False(t, hasID)
}
// TestApplyCodexOAuthTransform_EmptyInput verifies that an empty input array
// stays empty and the transform does not panic on it.
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{},
	}
	applyCodexOAuthTransform(reqBody)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 0)
}
// setupCodexCache points HOME at a temp dir and seeds a fresh opencode header
// cache there, so tests never trigger a network fetch of the header.
func setupCodexCache(t *testing.T) {
	t.Helper()
	// Use a temporary HOME to avoid a network pull of the header.
	tempDir := t.TempDir()
	t.Setenv("HOME", tempDir)
	cacheDir := filepath.Join(tempDir, ".opencode", "cache")
	require.NoError(t, os.MkdirAll(cacheDir, 0o755))
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
	// Metadata marked freshly checked keeps the cache inside its TTL window.
	meta := map[string]any{
		"etag":        "",
		"lastFetch":   time.Now().UTC().Format(time.RFC3339),
		"lastChecked": time.Now().UnixMilli(),
	}
	data, err := json.Marshal(meta)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}
backend/internal/service/openai_gateway_service.go
View file @
b9b4db3d
...
...
@@ -20,6 +20,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
...
...
@@ -41,6 +42,7 @@ var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
var
openaiAllowedHeaders
=
map
[
string
]
bool
{
"accept-language"
:
true
,
"content-type"
:
true
,
"conversation_id"
:
true
,
"user-agent"
:
true
,
"originator"
:
true
,
"session_id"
:
true
,
...
...
@@ -84,12 +86,15 @@ type OpenAIGatewayService struct {
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
concurrencyService
*
ConcurrencyService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
openAITokenProvider
*
OpenAITokenProvider
toolCorrector
*
CodexToolCorrector
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
...
...
@@ -100,12 +105,14 @@ func NewOpenAIGatewayService(
userSubRepo
UserSubscriptionRepository
,
cache
GatewayCache
,
cfg
*
config
.
Config
,
schedulerSnapshot
*
SchedulerSnapshotService
,
concurrencyService
*
ConcurrencyService
,
billingService
*
BillingService
,
rateLimitService
*
RateLimitService
,
billingCacheService
*
BillingCacheService
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
openAITokenProvider
*
OpenAITokenProvider
,
)
*
OpenAIGatewayService
{
return
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
...
...
@@ -114,12 +121,15 @@ func NewOpenAIGatewayService(
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
openAITokenProvider
:
openAITokenProvider
,
toolCorrector
:
NewCodexToolCorrector
(),
}
}
...
...
@@ -158,7 +168,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
// Refresh sticky session TTL
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
,
openaiStickySessionTTL
)
...
...
@@ -169,16 +179,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
}
// 2. Get schedulable OpenAI accounts
var
accounts
[]
Account
var
err
error
// 简易模式:忽略分组限制,查询所有可用账号
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformOpenAI
)
}
else
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformOpenAI
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformOpenAI
)
}
accounts
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
...
...
@@ -190,6 +191,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
continue
}
// Check model support
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
continue
...
...
@@ -300,7 +306,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
...
...
@@ -336,6 +342,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
isExcluded
(
acc
.
ID
)
{
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if
!
acc
.
IsSchedulable
()
{
continue
}
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
continue
}
...
...
@@ -445,6 +457,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
func
(
s
*
OpenAIGatewayService
)
listSchedulableAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
([]
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
accounts
,
_
,
err
:=
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
PlatformOpenAI
,
false
)
return
accounts
,
err
}
var
accounts
[]
Account
var
err
error
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
...
...
@@ -467,6 +483,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
return
s
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
}
// getSchedulableAccount resolves an account by ID, preferring the scheduler
// snapshot (a cached view) when one is configured and falling back to the
// account repository otherwise.
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
	if s.schedulerSnapshot != nil {
		return s.schedulerSnapshot.GetAccount(ctx, accountID)
	}
	return s.accountRepo.GetByID(ctx, accountID)
}
func
(
s
*
OpenAIGatewayService
)
schedulingConfig
()
config
.
GatewaySchedulingConfig
{
if
s
.
cfg
!=
nil
{
return
s
.
cfg
.
Gateway
.
Scheduling
...
...
@@ -485,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func
(
s
*
OpenAIGatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
case
AccountTypeOAuth
:
// 使用 TokenProvider 获取缓存的 token
if
s
.
openAITokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
openAITokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 降级:TokenProvider 未配置时直接从账号读取
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
@@ -511,7 +543,7 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
}
func
(
s
*
OpenAIGatewayService
)
handleFailoverSideEffects
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
account
*
Account
)
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
)
)
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
...
...
@@ -528,32 +560,96 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract model and stream from parsed body
reqModel
,
_
:=
reqBody
[
"model"
]
.
(
string
)
reqStream
,
_
:=
reqBody
[
"stream"
]
.
(
bool
)
promptCacheKey
:=
""
if
v
,
ok
:=
reqBody
[
"prompt_cache_key"
]
.
(
string
);
ok
{
promptCacheKey
=
strings
.
TrimSpace
(
v
)
}
// Track if body needs re-serialization
bodyModified
:=
false
originalModel
:=
reqModel
// Apply model mapping
isCodexCLI
:=
openai
.
IsCodexCLIRequest
(
c
.
GetHeader
(
"User-Agent"
))
// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
log
.
Printf
(
"[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)"
,
reqModel
,
mappedModel
,
account
.
Name
,
isCodexCLI
)
reqBody
[
"model"
]
=
mappedModel
bodyModified
=
true
}
// For OAuth accounts using ChatGPT internal API:
// 1. Add store: false
// 2. Normalize input format for Codex API compatibility
if
account
.
Type
==
AccountTypeOAuth
{
reqBody
[
"store"
]
=
false
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if
model
,
ok
:=
reqBody
[
"model"
]
.
(
string
);
ok
{
normalizedModel
:=
normalizeCodexModel
(
model
)
if
normalizedModel
!=
""
&&
normalizedModel
!=
model
{
log
.
Printf
(
"[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)"
,
model
,
normalizedModel
,
account
.
Name
,
account
.
Type
,
isCodexCLI
)
reqBody
[
"model"
]
=
normalizedModel
mappedModel
=
normalizedModel
bodyModified
=
true
}
}
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
if
reasoning
,
ok
:=
reqBody
[
"reasoning"
]
.
(
map
[
string
]
any
);
ok
{
if
effort
,
ok
:=
reasoning
[
"effort"
]
.
(
string
);
ok
&&
effort
==
"minimal"
{
reasoning
[
"effort"
]
=
"none"
bodyModified
=
true
log
.
Printf
(
"[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)"
,
account
.
Name
)
}
}
if
account
.
Type
==
AccountTypeOAuth
&&
!
isCodexCLI
{
codexResult
:=
applyCodexOAuthTransform
(
reqBody
)
if
codexResult
.
Modified
{
bodyModified
=
true
}
if
codexResult
.
NormalizedModel
!=
""
{
mappedModel
=
codexResult
.
NormalizedModel
}
if
codexResult
.
PromptCacheKey
!=
""
{
promptCacheKey
=
codexResult
.
PromptCacheKey
}
}
// Handle max_output_tokens based on platform and account type
if
!
isCodexCLI
{
if
maxOutputTokens
,
hasMaxOutputTokens
:=
reqBody
[
"max_output_tokens"
];
hasMaxOutputTokens
{
switch
account
.
Platform
{
case
PlatformOpenAI
:
// For OpenAI API Key, remove max_output_tokens (not supported)
// For OpenAI OAuth (Responses API), keep it (supported)
if
account
.
Type
==
AccountTypeAPIKey
{
delete
(
reqBody
,
"max_output_tokens"
)
bodyModified
=
true
}
case
PlatformAnthropic
:
// For Anthropic (Claude), convert to max_tokens
delete
(
reqBody
,
"max_output_tokens"
)
if
_
,
hasMaxTokens
:=
reqBody
[
"max_tokens"
];
!
hasMaxTokens
{
reqBody
[
"max_tokens"
]
=
maxOutputTokens
}
bodyModified
=
true
case
PlatformGemini
:
// For Gemini, remove (will be handled by Gemini-specific transform)
delete
(
reqBody
,
"max_output_tokens"
)
bodyModified
=
true
default
:
// For unknown platforms, remove to be safe
delete
(
reqBody
,
"max_output_tokens"
)
bodyModified
=
true
}
}
//
Normalize input format: convert AI SDK multi-part content format to simplified format
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
// Codex API expects: {"
cont
ent": "..."}
if
normalizeInputForCodexAPI
(
reqBody
)
{
//
Also handle max_completion_tokens (similar logic)
if
_
,
hasMaxCompletionTokens
:=
reqBody
[
"max_completion_tokens"
];
hasMaxCompletionTokens
{
if
account
.
Type
==
AccountTypeAPIKey
||
ac
co
u
nt
.
Platform
!=
PlatformOpenAI
{
delete
(
reqBody
,
"max_completion_tokens"
)
bodyModified
=
true
}
}
}
// Re-serialize body only if modified
if
bodyModified
{
...
...
@@ -571,7 +667,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
reqStream
)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
reqStream
,
promptCacheKey
,
isCodexCLI
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -582,16 +678,63 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL
=
account
.
Proxy
.
URL
()
}
// Capture upstream request body for ops retry of this attempt.
if
c
!=
nil
{
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
}
// Send request
resp
,
err
:=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream request failed"
,
},
})
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %s"
,
safeErr
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// Handle error response
if
resp
.
StatusCode
>=
400
{
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"failover"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
s
.
handleFailoverSideEffects
(
ctx
,
resp
,
account
)
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
...
...
@@ -632,7 +775,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
},
nil
}
func
(
s
*
OpenAIGatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
string
,
isStream
bool
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
OpenAIGatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
string
,
isStream
bool
,
promptCacheKey
string
,
isCodexCLI
bool
)
(
*
http
.
Request
,
error
)
{
// Determine target URL based on account type
var
targetURL
string
switch
account
.
Type
{
...
...
@@ -672,12 +815,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if
chatgptAccountID
!=
""
{
req
.
Header
.
Set
(
"chatgpt-account-id"
,
chatgptAccountID
)
}
// Set accept header based on stream mode
if
isStream
{
req
.
Header
.
Set
(
"accept"
,
"text/event-stream"
)
}
else
{
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
}
}
// Whitelist passthrough headers
...
...
@@ -689,6 +826,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
}
}
if
account
.
Type
==
AccountTypeOAuth
{
req
.
Header
.
Set
(
"OpenAI-Beta"
,
"responses=experimental"
)
if
isCodexCLI
{
req
.
Header
.
Set
(
"originator"
,
"codex_cli_rs"
)
}
else
{
req
.
Header
.
Set
(
"originator"
,
"opencode"
)
}
req
.
Header
.
Set
(
"accept"
,
"text/event-stream"
)
if
promptCacheKey
!=
""
{
req
.
Header
.
Set
(
"conversation_id"
,
promptCacheKey
)
req
.
Header
.
Set
(
"session_id"
,
promptCacheKey
)
}
}
// Apply custom User-Agent if configured
customUA
:=
account
.
GetOpenAIUserAgent
()
...
...
@@ -705,24 +855,74 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
func
(
s
*
OpenAIGatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
)
(
*
OpenAIForwardResult
,
error
)
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
body
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s"
,
resp
.
StatusCode
,
account
.
ID
,
account
.
Platform
,
account
.
Type
,
truncateForLog
(
body
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
),
)
}
// Check custom error codes
if
!
account
.
ShouldHandleErrorCode
(
resp
.
StatusCode
)
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream gateway error"
,
},
})
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (not in custom error codes)"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (not in custom error codes) message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
// Handle upstream error (mark account status)
shouldDisable
:=
false
if
s
.
rateLimitService
!=
nil
{
shouldDisable
=
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
kind
:=
"http_error"
if
shouldDisable
{
kind
=
"failover"
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
kind
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
if
shouldDisable
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
...
...
@@ -761,7 +961,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
},
})
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
// openaiStreamingResult streaming response result
...
...
@@ -905,6 +1108,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
line
=
"data: "
+
correctedData
}
// Forward line
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
...
...
@@ -933,6 +1141,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if
s
.
rateLimitService
!=
nil
{
s
.
rateLimitService
.
HandleStreamTimeout
(
ctx
,
account
,
originalModel
)
}
sendErrorEvent
(
"stream_timeout"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
...
...
@@ -988,6 +1200,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return
line
}
// correctToolCallsInResponseBody 修正响应体中的工具调用
func
(
s
*
OpenAIGatewayService
)
correctToolCallsInResponseBody
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
bodyStr
:=
string
(
body
)
corrected
,
changed
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
bodyStr
)
if
changed
{
return
[]
byte
(
corrected
)
}
return
body
}
func
(
s
*
OpenAIGatewayService
)
parseSSEUsage
(
data
string
,
usage
*
OpenAIUsage
)
{
// Parse response.completed event for usage (OpenAI Responses format)
var
event
struct
{
...
...
@@ -1016,6 +1242,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return
nil
,
err
}
if
account
.
Type
==
AccountTypeOAuth
{
bodyLooksLikeSSE
:=
bytes
.
Contains
(
body
,
[]
byte
(
"data:"
))
||
bytes
.
Contains
(
body
,
[]
byte
(
"event:"
))
if
isEventStreamResponse
(
resp
.
Header
)
||
bodyLooksLikeSSE
{
return
s
.
handleOAuthSSEToJSON
(
resp
,
c
,
body
,
originalModel
,
mappedModel
)
}
}
// Parse usage
var
response
struct
{
Usage
struct
{
...
...
@@ -1055,6 +1288,112 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return
usage
,
nil
}
func
isEventStreamResponse
(
header
http
.
Header
)
bool
{
contentType
:=
strings
.
ToLower
(
header
.
Get
(
"Content-Type"
))
return
strings
.
Contains
(
contentType
,
"text/event-stream"
)
}
// handleOAuthSSEToJSON handles the case where a client asked for a
// non-streaming response but the ChatGPT internal (OAuth) upstream replied
// with an SSE body. If a final "response.done"/"response.completed" event is
// found, its payload is extracted and returned to the client as plain JSON;
// otherwise the SSE body is forwarded as-is (with usage parsed from the
// events). Returns the usage extracted from whichever path was taken.
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
	bodyText := string(body)
	// ok is true when the SSE stream contained a final response event whose
	// JSON payload could be extracted.
	finalResponse, ok := extractCodexFinalResponse(bodyText)
	usage := &OpenAIUsage{}
	if ok {
		// Pull token usage out of the final response payload. An unmarshal
		// failure is tolerated: the body is still forwarded, just with zero usage.
		var response struct {
			Usage struct {
				InputTokens       int `json:"input_tokens"`
				OutputTokens      int `json:"output_tokens"`
				InputTokenDetails struct {
					CachedTokens int `json:"cached_tokens"`
				} `json:"input_tokens_details"`
			} `json:"usage"`
		}
		if err := json.Unmarshal(finalResponse, &response); err == nil {
			usage.InputTokens = response.Usage.InputTokens
			usage.OutputTokens = response.Usage.OutputTokens
			usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
		}
		body = finalResponse
		// Restore the model name the client originally requested if a mapping
		// was applied on the way upstream.
		if originalModel != mappedModel {
			body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
		}
		// Correct tool calls in final response (e.g. apply_patch -> edit).
		body = s.correctToolCallsInResponseBody(body)
	} else {
		// No final event: forward the raw SSE body, parsing usage from the
		// individual data events and un-mapping the model name per line.
		usage = s.parseSSEUsageFromBody(bodyText)
		if originalModel != mappedModel {
			bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
		}
		body = []byte(bodyText)
	}
	// Copy through the security-filtered subset of upstream headers.
	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
	contentType := "application/json; charset=utf-8"
	if !ok {
		// Raw SSE passthrough keeps the upstream content type (defaulting to
		// text/event-stream when the upstream omitted it).
		contentType = resp.Header.Get("Content-Type")
		if contentType == "" {
			contentType = "text/event-stream"
		}
	}
	c.Data(resp.StatusCode, contentType, body)
	return usage, nil
}
func
extractCodexFinalResponse
(
body
string
)
([]
byte
,
bool
)
{
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
for
_
,
line
:=
range
lines
{
if
!
openaiSSEDataRe
.
MatchString
(
line
)
{
continue
}
data
:=
openaiSSEDataRe
.
ReplaceAllString
(
line
,
""
)
if
data
==
""
||
data
==
"[DONE]"
{
continue
}
var
event
struct
{
Type
string
`json:"type"`
Response
json
.
RawMessage
`json:"response"`
}
if
json
.
Unmarshal
([]
byte
(
data
),
&
event
)
!=
nil
{
continue
}
if
event
.
Type
==
"response.done"
||
event
.
Type
==
"response.completed"
{
if
len
(
event
.
Response
)
>
0
{
return
event
.
Response
,
true
}
}
}
return
nil
,
false
}
func
(
s
*
OpenAIGatewayService
)
parseSSEUsageFromBody
(
body
string
)
*
OpenAIUsage
{
usage
:=
&
OpenAIUsage
{}
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
for
_
,
line
:=
range
lines
{
if
!
openaiSSEDataRe
.
MatchString
(
line
)
{
continue
}
data
:=
openaiSSEDataRe
.
ReplaceAllString
(
line
,
""
)
if
data
==
""
||
data
==
"[DONE]"
{
continue
}
s
.
parseSSEUsage
(
data
,
usage
)
}
return
usage
}
func
(
s
*
OpenAIGatewayService
)
replaceModelInSSEBody
(
body
,
fromModel
,
toModel
string
)
string
{
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
for
i
,
line
:=
range
lines
{
if
!
openaiSSEDataRe
.
MatchString
(
line
)
{
continue
}
lines
[
i
]
=
s
.
replaceModelInSSELine
(
line
,
fromModel
,
toModel
)
}
return
strings
.
Join
(
lines
,
"
\n
"
)
}
func
(
s
*
OpenAIGatewayService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
...
...
@@ -1094,101 +1433,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
return
newBody
}
// normalizeInputForCodexAPI converts the AI SDK multi-part content format to
// the simplified format that the ChatGPT internal Codex API expects.
//
// AI SDK sends content as an array of typed objects:
//
//	{"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
//	{"content": "hello"}
//
// reqBody is mutated in place; the return value reports whether any message
// content was rewritten.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
	raw, present := reqBody["input"]
	if !present {
		return false
	}
	// A plain string input is already in the target shape.
	if _, alreadyString := raw.(string); alreadyString {
		return false
	}
	messages, isArray := raw.([]any)
	if !isArray {
		return false
	}

	changed := false
	for _, entry := range messages {
		msg, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		rawContent, hasContent := msg["content"]
		if !hasContent {
			continue
		}
		// Content that is already a string needs no conversion.
		if _, alreadyString := rawContent.(string); alreadyString {
			continue
		}
		parts, isParts := rawContent.([]any)
		if !isParts {
			continue
		}

		// Collect the text of each part; non-text parts are dropped.
		var texts []string
		for _, rawPart := range parts {
			part, isPartMap := rawPart.(map[string]any)
			if !isPartMap {
				continue
			}
			kind, _ := part["type"].(string)
			switch kind {
			case "input_text", "text":
				if txt, isText := part["text"].(string); isText {
					texts = append(texts, txt)
				}
			case "input_image", "image", "input_file", "file":
				// Image and file parts are skipped: the simplified content
				// field carries only text, so they are lost in conversion.
				// TODO: consider preserving image/file data separately.
				continue
			default:
				// Unknown part types still contribute their text when present.
				if txt, isText := part["text"].(string); isText {
					texts = append(texts, txt)
				}
			}
		}
		if len(texts) > 0 {
			msg["content"] = strings.Join(texts, "\n")
			changed = true
		}
	}
	return changed
}
// OpenAIRecordUsageInput input for recording usage
type
OpenAIRecordUsageInput
struct
{
Result
*
OpenAIForwardResult
...
...
@@ -1197,6 +1441,7 @@ type OpenAIRecordUsageInput struct {
Account
*
Account
Subscription
*
UserSubscription
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
}
// RecordUsage records usage and deducts balance
...
...
@@ -1242,6 +1487,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
...
...
@@ -1259,6 +1505,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
...
...
@@ -1271,6 +1518,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog
.
UserAgent
=
&
input
.
UserAgent
}
// 添加 IPAddress
if
input
.
IPAddress
!=
""
{
usageLog
.
IPAddress
=
&
input
.
IPAddress
}
if
apiKey
.
GroupID
!=
nil
{
usageLog
.
GroupID
=
apiKey
.
GroupID
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
b9b4db3d
...
...
@@ -3,6 +3,7 @@ package service
import
(
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
...
...
@@ -15,6 +16,129 @@ import (
"github.com/gin-gonic/gin"
)
// stubOpenAIAccountRepo is a test double for AccountRepository that serves a
// fixed in-memory account list regardless of the group or platform filter.
type stubOpenAIAccountRepo struct {
	// Embedded so that any AccountRepository method the tests do not override
	// panics with a nil-method call if exercised unexpectedly.
	AccountRepository

	accounts []Account // accounts returned by both list methods
}

// ListSchedulableByGroupIDAndPlatform returns a copy of the stub's accounts,
// ignoring the groupID and platform arguments.
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
	// Copy so callers cannot mutate the stub's backing slice.
	return append([]Account(nil), r.accounts...), nil
}

// ListSchedulableByPlatform returns a copy of the stub's accounts, ignoring
// the platform argument.
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
	return append([]Account(nil), r.accounts...), nil
}
// stubConcurrencyCache is a test double for ConcurrencyCache that always
// grants a slot and reports zero load for every account.
type stubConcurrencyCache struct {
	// Embedded so unimplemented ConcurrencyCache methods panic if called.
	ConcurrencyCache
}

// AcquireAccountSlot always succeeds, regardless of concurrency limits.
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	return true, nil
}

// ReleaseAccountSlot is a no-op.
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	return nil
}

// GetAccountsLoadBatch reports a zero load rate for each requested account.
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	out := make(map[int64]*AccountLoadInfo, len(accounts))
	for _, acc := range accounts {
		out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
	}
	return out, nil
}
// TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable verifies that
// an account whose rate-limit reset time lies in the future is skipped, and
// the remaining schedulable account is selected, when a concurrency service
// (backed by the always-granting stub cache) is configured.
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
	now := time.Now()
	// Reset time 10 minutes in the future keeps this account rate-limited.
	resetAt := now.Add(10 * time.Minute)
	groupID := int64(1)
	rateLimited := Account{
		ID:               1,
		Platform:         PlatformOpenAI,
		Type:             AccountTypeAPIKey,
		Status:           StatusActive,
		Schedulable:      true,
		Concurrency:      1,
		Priority:         0, // better priority than "available", so selection must actively filter it
		RateLimitResetAt: &resetAt,
	}
	available := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    1,
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
	if err != nil {
		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
	}
	if selection == nil || selection.Account == nil {
		t.Fatalf("expected selection with account")
	}
	if selection.Account.ID != available.ID {
		t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
	}
	// Release the acquired concurrency slot, if any, to keep the test clean.
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService
// verifies the same rate-limit filtering as the test above, but on the code
// path taken when no concurrency service is configured (no load-batch lookup).
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
	now := time.Now()
	// Reset time 10 minutes in the future keeps this account rate-limited.
	resetAt := now.Add(10 * time.Minute)
	groupID := int64(1)
	rateLimited := Account{
		ID:               1,
		Platform:         PlatformOpenAI,
		Type:             AccountTypeAPIKey,
		Status:           StatusActive,
		Schedulable:      true,
		Concurrency:      1,
		Priority:         0, // better priority than "available", so selection must actively filter it
		RateLimitResetAt: &resetAt,
	}
	available := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    1,
	}
	svc := &OpenAIGatewayService{
		accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
		// concurrencyService is nil, forcing the non-load-batch selection path.
	}
	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
	if err != nil {
		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
	}
	if selection == nil || selection.Account == nil {
		t.Fatalf("expected selection with account")
	}
	if selection.Account.ID != available.ID {
		t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
	}
	// Release the acquired concurrency slot, if any, to keep the test clean.
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
func
TestOpenAIStreamingTimeout
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
...
...
@@ -220,7 +344,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
Credentials
:
map
[
string
]
any
{
"base_url"
:
"://invalid-url"
},
}
_
,
err
:=
svc
.
buildUpstreamRequest
(
c
.
Request
.
Context
(),
c
,
account
,
[]
byte
(
"{}"
),
"token"
,
false
)
_
,
err
:=
svc
.
buildUpstreamRequest
(
c
.
Request
.
Context
(),
c
,
account
,
[]
byte
(
"{}"
),
"token"
,
false
,
""
,
false
)
if
err
==
nil
{
t
.
Fatalf
(
"expected error for invalid base_url when allowlist disabled"
)
}
...
...
backend/internal/service/openai_gateway_service_tool_correction_test.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"strings"
"testing"
)
// TestOpenAIGatewayService_ToolCorrection exercises the tool-correction
// integration in OpenAIGatewayService: known-bad tool names inside a response
// body are rewritten (apply_patch -> edit, update_plan -> todowrite) while
// already-correct names are passed through unchanged.
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
	// A minimal service instance: only the tool corrector is needed here.
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}

	tests := []struct {
		name     string
		input    []byte  // response body fed to correctToolCallsInResponseBody
		expected string  // substring the corrected body must contain
		changed  bool    // whether the body is expected to differ from input
	}{
		{
			name: "correct apply_patch in response body",
			input: []byte(`{
	"choices": [{
		"message": {
			"tool_calls": [{
				"function": {"name": "apply_patch"}
			}]
		}
	}]
}`),
			expected: "edit",
			changed:  true,
		},
		{
			name: "correct update_plan in response body",
			input: []byte(`{
	"tool_calls": [{
		"function": {"name": "update_plan"}
	}]
}`),
			expected: "todowrite",
			changed:  true,
		},
		{
			name: "no change for correct tool name",
			input: []byte(`{
	"tool_calls": [{
		"function": {"name": "edit"}
	}]
}`),
			expected: "edit",
			changed:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := service.correctToolCallsInResponseBody(tt.input)
			resultStr := string(result)
			// The corrected body must contain the expected tool name.
			if !strings.Contains(resultStr, tt.expected) {
				t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
			}
			// When a correction is expected, the output must differ from the input.
			if tt.changed && string(result) == string(tt.input) {
				t.Error("expected result to be different from input, but they are the same")
			}
			// When no correction is expected, the output must equal the input.
			if !tt.changed && string(result) != string(tt.input) {
				t.Error("expected result to be same as input, but they are different")
			}
		})
	}
}
// TestOpenAIGatewayService_ToolCorrectorInitialization verifies that the tool
// corrector can be constructed and performs a basic correction end to end.
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}
	if service.toolCorrector == nil {
		t.Fatal("toolCorrector should not be nil")
	}
	// Sanity check: the corrector actually rewrites a known-bad tool name.
	data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
	corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
	if !changed {
		t.Error("expected tool call to be corrected")
	}
	if !strings.Contains(corrected, "edit") {
		t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
	}
}
// TestToolCorrectionStats verifies that the tool corrector's statistics
// counters track the total number of corrections and the per-mapping counts.
func TestToolCorrectionStats(t *testing.T) {
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}
	// Run several corrections: two apply_patch rewrites and one update_plan.
	testData := []string{
		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
		`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
	}
	for _, data := range testData {
		service.toolCorrector.CorrectToolCallsInSSEData(data)
	}
	stats := service.toolCorrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
	}
	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}
	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}
}
backend/internal/service/openai_token_provider.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
	// openAITokenRefreshSkew is the window before expiry within which a token
	// is considered stale and a refresh is attempted.
	openAITokenRefreshSkew = 3 * time.Minute
	// openAITokenCacheSkew is subtracted from the token's remaining lifetime
	// when computing the cache TTL, so cached tokens expire before the token does.
	openAITokenCacheSkew = 5 * time.Minute
	// openAILockWaitTime is how long a worker waits before re-reading the
	// cache when another worker holds the refresh lock.
	openAILockWaitTime = 200 * time.Millisecond
)
// OpenAITokenCache is the token cache interface (reuses the GeminiTokenCache
// interface definition).
type OpenAITokenCache = GeminiTokenCache

// OpenAITokenProvider manages access tokens for OpenAI OAuth accounts,
// combining a shared cache, a distributed refresh lock, and the OAuth
// refresh service.
type OpenAITokenProvider struct {
	accountRepo        AccountRepository    // source of the persisted account credentials
	tokenCache         OpenAITokenCache     // shared token cache and refresh lock; may be nil
	openAIOAuthService *OpenAIOAuthService  // performs the actual token refresh; may be nil
}
func
NewOpenAITokenProvider
(
accountRepo
AccountRepository
,
tokenCache
OpenAITokenCache
,
openAIOAuthService
*
OpenAIOAuthService
,
)
*
OpenAITokenProvider
{
return
&
OpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
openAIOAuthService
:
openAIOAuthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
OpenAITokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformOpenAI
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai oauth account"
)
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
else
if
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_get_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Debug
(
"openai_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2. 如果即将过期则刷新
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"openai_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"openai_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
openAILockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
openAITokenCacheSkew
:
ttl
=
until
-
openAITokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
accessToken
,
nil
}
backend/internal/service/openai_token_provider_test.go
0 → 100644
View file @
b9b4db3d
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// openAITokenCacheStub implements OpenAITokenCache for testing
type
openAITokenCacheStub
struct
{
mu
sync
.
Mutex
tokens
map
[
string
]
string
getErr
error
setErr
error
deleteErr
error
lockAcquired
bool
lockErr
error
releaseLockErr
error
getCalled
int32
setCalled
int32
lockCalled
int32
unlockCalled
int32
simulateLockRace
bool
}
func
newOpenAITokenCacheStub
()
*
openAITokenCacheStub
{
return
&
openAITokenCacheStub
{
tokens
:
make
(
map
[
string
]
string
),
lockAcquired
:
true
,
}
}
func
(
s
*
openAITokenCacheStub
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
atomic
.
AddInt32
(
&
s
.
getCalled
,
1
)
if
s
.
getErr
!=
nil
{
return
""
,
s
.
getErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
return
s
.
tokens
[
cacheKey
],
nil
}
func
(
s
*
openAITokenCacheStub
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
atomic
.
AddInt32
(
&
s
.
setCalled
,
1
)
if
s
.
setErr
!=
nil
{
return
s
.
setErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
tokens
[
cacheKey
]
=
token
return
nil
}
func
(
s
*
openAITokenCacheStub
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
if
s
.
deleteErr
!=
nil
{
return
s
.
deleteErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
delete
(
s
.
tokens
,
cacheKey
)
return
nil
}
func
(
s
*
openAITokenCacheStub
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
atomic
.
AddInt32
(
&
s
.
lockCalled
,
1
)
if
s
.
lockErr
!=
nil
{
return
false
,
s
.
lockErr
}
if
s
.
simulateLockRace
{
return
false
,
nil
}
return
s
.
lockAcquired
,
nil
}
func
(
s
*
openAITokenCacheStub
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
atomic
.
AddInt32
(
&
s
.
unlockCalled
,
1
)
return
s
.
releaseLockErr
}
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
type
openAIAccountRepoStub
struct
{
account
*
Account
getErr
error
updateErr
error
getCalled
int32
updateCalled
int32
}
func
(
r
*
openAIAccountRepoStub
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
atomic
.
AddInt32
(
&
r
.
getCalled
,
1
)
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
return
r
.
account
,
nil
}
func
(
r
*
openAIAccountRepoStub
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
atomic
.
AddInt32
(
&
r
.
updateCalled
,
1
)
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
account
=
account
return
nil
}
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
type
openAIOAuthServiceStub
struct
{
tokenInfo
*
OpenAITokenInfo
refreshErr
error
refreshCalled
int32
}
func
(
s
*
openAIOAuthServiceStub
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
OpenAITokenInfo
,
error
)
{
atomic
.
AddInt32
(
&
s
.
refreshCalled
,
1
)
if
s
.
refreshErr
!=
nil
{
return
nil
,
s
.
refreshErr
}
return
s
.
tokenInfo
,
nil
}
func
(
s
*
openAIOAuthServiceStub
)
BuildAccountCredentials
(
info
*
OpenAITokenInfo
)
map
[
string
]
any
{
now
:=
time
.
Now
()
return
map
[
string
]
any
{
"access_token"
:
info
.
AccessToken
,
"refresh_token"
:
info
.
RefreshToken
,
"expires_at"
:
now
.
Add
(
time
.
Duration
(
info
.
ExpiresIn
)
*
time
.
Second
)
.
Format
(
time
.
RFC3339
),
}
}
// TestOpenAITokenProvider_CacheHit verifies that a token already in the cache
// is returned directly: one cache read, no cache write.
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	account := &Account{
		ID:       100,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	cache.tokens[cacheKey] = "cached-token"

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "cached-token", token)
	// Exactly one read and no write proves the credentials path was never taken.
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}

// TestOpenAITokenProvider_CacheMiss_FromCredentials verifies that on a cache
// miss with a far-future expiry the credential token is returned and cached.
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	// Token expires in far future, no refresh needed
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       101,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "credential-token", token)
	// Should have stored in cache
	cacheKey := OpenAITokenCacheKey(account)
	require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
// TestOpenAITokenProvider_TokenRefresh verifies that a token expiring within
// the refresh skew triggers exactly one OAuth refresh and that the refreshed
// token is returned. Uses testOpenAITokenProvider so the stub OAuth service
// can be injected.
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	// We need to directly test with the stub - create a custom provider
	customProvider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	token, err := customProvider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
// testOpenAITokenProvider is a test-only reimplementation of the provider's
// GetAccessToken flow wired to the stub types, so the OAuth service can be
// substituted. It mirrors the production logic: cache check → refresh-lock →
// double-check → DB reload → refresh → TTL selection. Keep it in sync with
// the real provider when that logic changes.
type testOpenAITokenProvider struct {
	accountRepo  *openAIAccountRepoStub
	tokenCache   *openAITokenCacheStub
	oauthService *openAIOAuthServiceStub
}

// GetAccessToken resolves an access token for an OpenAI OAuth account,
// refreshing it when it is missing an expiry or expires within
// openAITokenRefreshSkew. Falls back to the credential token when refresh
// is unavailable or fails (with a short cache TTL in that case).
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}
	cacheKey := OpenAITokenCacheKey(account)
	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}
	// 2. Check if refresh needed
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
			// Check cache again after acquiring lock (another worker may have refreshed)
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh; mark failure
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed; mark so a short TTL is used
					} else {
						// Merge: refreshed credentials win; keep any extra keys from the old map.
						newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account) // persistence errors are non-fatal in the test flow
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Lock held elsewhere: wait briefly and retry the cache.
			time.Sleep(10 * time.Millisecond) // short wait for test
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}
	accessToken := account.GetOpenAIAccessToken()
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // short TTL after a failed refresh, to avoid caching a stale token
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > openAITokenCacheSkew {
				ttl = until - openAITokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) // cache failures are tolerated
	}
	return accessToken, nil
}
// TestOpenAITokenProvider_LockRaceCondition verifies behavior when the refresh
// lock is held by another worker (simulateLockRace): the provider waits and
// re-reads the cache; either the winner's cached token or the original
// credential token is acceptable.
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &openAIAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	// Simulate another worker already refreshed and cached
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := &testOpenAITokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get the token set by the "winner" or the original
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_NilAccount verifies that a nil account is rejected.
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_WrongPlatform verifies that a non-OpenAI account is rejected.
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       104,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_WrongAccountType verifies that a non-OAuth (API-key)
// account is rejected.
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       105,
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_NilCache verifies that the provider works with no
// cache configured: the token is read straight from credentials.
func TestOpenAITokenProvider_NilCache(t *testing.T) {
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       106,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", token)
}

// TestOpenAITokenProvider_CacheGetError verifies graceful degradation when the
// cache read fails: the credential token is returned, not an error.
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.getErr = errors.New("redis connection failed")
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       107,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	// Should gracefully degrade and return from credentials
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", token)
}

// TestOpenAITokenProvider_CacheSetError verifies that a failing cache write is
// tolerated: the token is still returned.
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.setErr = errors.New("redis write failed")
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       108,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	// Should still work even if cache set fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", token)
}
// TestOpenAITokenProvider_MissingAccessToken verifies that credentials without
// an access_token yield an explicit error.
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       109,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// missing access_token
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_RefreshError verifies fallback behavior when the
// OAuth refresh fails: the existing credential token is returned without error.
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		refreshErr: errors.New("oauth refresh failed"),
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       110,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Now with fallback behavior, should return existing token even if refresh fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}

// TestOpenAITokenProvider_OAuthServiceNotConfigured verifies fallback when no
// OAuth service is wired in: the existing credential token is returned.
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       111,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: nil, // not configured
	}
	// Now with fallback behavior, should return existing token even if oauth service not configured
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}
// TestOpenAITokenProvider_TTLCalculation exercises the cache-TTL selection
// across expiry distances (far / medium / near). It only asserts the token is
// cached; the stub ignores the ttl argument, so the exact TTL value is not
// verified here.
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
	tests := []struct {
		name      string
		expiresIn time.Duration
	}{
		{name: "far_future_expiry", expiresIn: 1 * time.Hour},
		{name: "medium_expiry", expiresIn: 10 * time.Minute},
		{name: "near_expiry", expiresIn: 6 * time.Minute},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache := newOpenAITokenCacheStub()
			expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
			account := &Account{
				ID:       200,
				Platform: PlatformOpenAI,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   expiresAt,
				},
			}
			provider := NewOpenAITokenProvider(nil, cache, nil)
			_, err := provider.GetAccessToken(context.Background(), account)
			require.NoError(t, err)
			// Verify token was cached
			cacheKey := OpenAITokenCacheKey(account)
			require.Equal(t, "test-token", cache.tokens[cacheKey])
		})
	}
}
func
TestOpenAITokenProvider_DoubleCheckAfterLock
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
accountRepo
:=
&
openAIAccountRepoStub
{}
oauthService
:=
&
openAIOAuthServiceStub
{
tokenInfo
:
&
OpenAITokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
112
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
originalGet
:=
int32
(
0
)
cache
.
tokens
[
cacheKey
]
=
""
// Empty initially
provider
:=
&
testOpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// In a goroutine, set the cached token after a small delay (simulating race)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"cached-by-other"
cache
.
mu
.
Unlock
()
}()
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Should get either the refreshed token or the cached one
require
.
NotEmpty
(
t
,
token
)
_
=
originalGet
// Suppress unused warning
}
// Tests for real provider - to increase coverage

// TestOpenAITokenProvider_Real_LockFailedWait verifies that when the refresh
// lock is denied, the real provider waits and either picks up a token cached
// by another worker or falls back to the credential token.
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       200,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	// Set token in cache after lock wait period (simulate other worker refreshing)
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get either the fallback token or the refreshed one
	require.NotEmpty(t, token)
}

// TestOpenAITokenProvider_Real_CacheHitAfterWait is the variant where the
// competing worker's write lands early enough to be seen on the post-wait
// cache re-read.
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       201,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken verifies that an
// account with no expires_at (refresh path blocked by a denied lock and no
// OAuth service) still yields the credential token.
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Prevent entering refresh logic
	// Token with nil expires_at (no expiry set) - should use credentials
	account := &Account{
		ID:       202,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	// Without OAuth service, refresh will fail but token should be returned from credentials
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", token)
}

// TestOpenAITokenProvider_Real_WhitespaceToken verifies that a cached value of
// only whitespace is treated as a miss and the credential token is used.
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	// NOTE(review): hardcoded key — presumably matches OpenAITokenCacheKey for
	// account ID 203; confirm it stays in sync with that function's format.
	cacheKey := "openai:account:203"
	cache.tokens[cacheKey] = "   " // Whitespace only - should be treated as empty
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       203,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "real-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "real-token", token) // Should fall back to credentials
}
// TestOpenAITokenProvider_Real_LockError verifies the degraded path when lock
// acquisition errors (e.g. Redis down): the provider falls back to the
// credential token rather than failing.
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockErr = errors.New("redis lock failed")
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       204,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", token)
}
// TestOpenAITokenProvider_Real_WhitespaceCredentialToken verifies that a
// whitespace-only credential access_token is rejected as missing.
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       205,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "   ", // Whitespace only
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_Real_NilCredentials verifies that credentials with
// no access_token key at all are rejected as missing.
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       206,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// No access_token
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
backend/internal/service/openai_tool_continuation.go
0 → 100644
View file @
b9b4db3d
package
service
import
"strings"
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
// 或显式声明 tools/tool_choice。
func
NeedsToolContinuation
(
reqBody
map
[
string
]
any
)
bool
{
if
reqBody
==
nil
{
return
false
}
if
hasNonEmptyString
(
reqBody
[
"previous_response_id"
])
{
return
true
}
if
hasToolsSignal
(
reqBody
)
{
return
true
}
if
hasToolChoiceSignal
(
reqBody
)
{
return
true
}
if
inputHasType
(
reqBody
,
"function_call_output"
)
{
return
true
}
if
inputHasType
(
reqBody
,
"item_reference"
)
{
return
true
}
return
false
}
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
func
HasFunctionCallOutput
(
reqBody
map
[
string
]
any
)
bool
{
if
reqBody
==
nil
{
return
false
}
return
inputHasType
(
reqBody
,
"function_call_output"
)
}
// HasToolCallContext reports whether the input holds a tool_call or
// function_call item carrying a non-blank call_id — i.e. whether a
// function_call_output in the same request has context it can be linked to.
func HasToolCallContext(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		typ, _ := entry["type"].(string)
		switch typ {
		case "tool_call", "function_call":
			if id, ok := entry["call_id"].(string); ok && strings.TrimSpace(id) != "" {
				return true
			}
		}
	}
	return false
}
// FunctionCallOutputCallIDs collects the distinct, non-blank call_id values of
// all function_call_output items in the request input, for matching against
// item_reference.id. Returns nil when there are none.
//
// The IDs are returned in first-seen input order. (The previous implementation
// collected them in a map and returned them in random map-iteration order;
// callers only use the result for set membership, so a stable order is a
// backward-compatible improvement that makes logs and tests deterministic.)
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
	if reqBody == nil {
		return nil
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return nil
	}
	seen := make(map[string]struct{})
	var result []string
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if itemType, _ := itemMap["type"].(string); itemType != "function_call_output" {
			continue
		}
		callID, _ := itemMap["call_id"].(string)
		if strings.TrimSpace(callID) == "" {
			continue // blank call_id carries no linkable identity
		}
		if _, dup := seen[callID]; dup {
			continue
		}
		seen[callID] = struct{}{}
		result = append(result, callID)
	}
	return result
}
// HasFunctionCallOutputMissingCallID reports whether any function_call_output
// item in the input lacks a usable (non-blank) call_id.
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if typ, _ := entry["type"].(string); typ != "function_call_output" {
			continue
		}
		// A missing key, a non-string value, or a blank string all count as missing.
		id, _ := entry["call_id"].(string)
		if strings.TrimSpace(id) == "" {
			return true
		}
	}
	return false
}
// HasItemReferenceForCallIDs reports whether the item_reference ids found in
// the input cover every id in callIDs. Used to validate continuations that
// rely solely on reference items. Returns false for an empty callIDs slice,
// a missing/malformed input, or when no usable reference ids exist.
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if reqBody == nil || len(callIDs) == 0 {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	// Collect the non-blank ids of all item_reference entries.
	refs := map[string]struct{}{}
	for _, raw := range items {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if typ, _ := entry["type"].(string); typ != "item_reference" {
			continue
		}
		id, _ := entry["id"].(string)
		if id = strings.TrimSpace(id); id != "" {
			refs[id] = struct{}{}
		}
	}
	if len(refs) == 0 {
		return false
	}
	// Every call_id must be covered by a reference.
	for _, id := range callIDs {
		if _, covered := refs[id]; !covered {
			return false
		}
	}
	return true
}
// inputHasType reports whether the request's input slice contains an item
// whose "type" field equals want. Malformed input (non-slice, non-map items)
// never matches.
func inputHasType(reqBody map[string]any, want string) bool {
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		if entry, ok := raw.(map[string]any); ok {
			if typ, _ := entry["type"].(string); typ == want {
				return true
			}
		}
	}
	return false
}
// hasNonEmptyString reports whether value is a string with non-whitespace content.
func hasNonEmptyString(value any) bool {
	s, ok := value.(string)
	if !ok {
		return false
	}
	return strings.TrimSpace(s) != ""
}
// hasToolsSignal reports whether "tools" is explicitly declared as a non-empty
// slice. Absent, nil, empty, or non-slice values do not count.
func hasToolsSignal(reqBody map[string]any) bool {
	raw, exists := reqBody["tools"]
	if !exists || raw == nil {
		return false
	}
	tools, ok := raw.([]any)
	return ok && len(tools) > 0
}
// hasToolChoiceSignal reports whether "tool_choice" is explicitly declared:
// either a non-blank string (e.g. "auto") or a non-empty object. Absent, nil,
// blank, empty-object, or other-typed values do not count.
func hasToolChoiceSignal(reqBody map[string]any) bool {
	raw, exists := reqBody["tool_choice"]
	if !exists || raw == nil {
		return false
	}
	switch value := raw.(type) {
	case string:
		return strings.TrimSpace(value) != ""
	case map[string]any:
		return len(value) > 0
	}
	return false
}
backend/internal/service/openai_tool_continuation_test.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
// TestNeedsToolContinuationSignals covers every signal source that triggers
// continuation, plus the negative shapes (blank id, empty tools, wrong types)
// that must not trigger it.
func TestNeedsToolContinuationSignals(t *testing.T) {
	cases := []struct {
		name string
		body map[string]any
		want bool
	}{
		{name: "nil", body: nil, want: false},
		{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
		{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
		{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
		{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
		{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
		{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
		{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
		{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
		{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
		{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
		{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
		})
	}
}
// TestHasFunctionCallOutput verifies that only a function_call_output entry in
// the input array counts as continuation output.
func TestHasFunctionCallOutput(t *testing.T) {
	// Only a function_call_output item inside "input" counts as continuation output.
	require.False(t, HasFunctionCallOutput(nil))
	require.True(t, HasFunctionCallOutput(map[string]any{
		"input": []any{map[string]any{"type": "function_call_output"}},
	}))
	// A non-array "input" never matches.
	require.False(t, HasFunctionCallOutput(map[string]any{
		"input": "text",
	}))
}
// TestHasToolCallContext verifies that tool_call/function_call items must
// carry a call_id before they count as linkable context.
func TestHasToolCallContext(t *testing.T) {
	// tool_call/function_call must include a call_id to qualify as linkable context.
	require.False(t, HasToolCallContext(nil))
	require.True(t, HasToolCallContext(map[string]any{
		"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
	}))
	require.True(t, HasToolCallContext(map[string]any{
		"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
	}))
	// Missing call_id disqualifies the item.
	require.False(t, HasToolCallContext(map[string]any{
		"input": []any{map[string]any{"type": "tool_call"}},
	}))
}
// TestFunctionCallOutputCallIDs verifies extraction of call IDs from
// function_call_output items: blanks are dropped and duplicates collapsed.
func TestFunctionCallOutputCallIDs(t *testing.T) {
	// Only non-empty call_id values are extracted, deduplicated.
	require.Empty(t, FunctionCallOutputCallIDs(nil))
	callIDs := FunctionCallOutputCallIDs(map[string]any{
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
			// Empty call_id is skipped.
			map[string]any{"type": "function_call_output", "call_id": ""},
			// Duplicate of the first entry; must not appear twice.
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
	})
	require.ElementsMatch(t, []string{"call_1"}, callIDs)
}
// TestHasFunctionCallOutputMissingCallID verifies detection of
// function_call_output items that lack a call_id.
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
	require.False(t, HasFunctionCallOutputMissingCallID(nil))
	// An output without call_id is flagged.
	require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
		"input": []any{map[string]any{"type": "function_call_output"}},
	}))
	// An output with a call_id is fine.
	require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
		"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
	}))
}
// TestHasItemReferenceForCallIDs verifies that item_reference entries must
// cover every requested call_id before the context is considered linkable.
func TestHasItemReferenceForCallIDs(t *testing.T) {
	// item_reference entries must cover all call_ids to count as linkable context.
	require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
	require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
	req := map[string]any{
		"input": []any{
			map[string]any{"type": "item_reference", "id": "call_1"},
			map[string]any{"type": "item_reference", "id": "call_2"},
		},
	}
	// Subsets and exact sets of referenced IDs are both covered.
	require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
	require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
	// Any ID without a matching reference fails the whole check.
	require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
}
backend/internal/service/openai_tool_corrector.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"encoding/json"
"fmt"
"log"
"sync"
)
// codexToolNameMapping maps Codex-native tool names to their OpenCode
// equivalents. Both snake_case and camelCase spellings are listed because
// either form may appear in tool calls.
var codexToolNameMapping = map[string]string{
	"apply_patch":  "edit",
	"applyPatch":   "edit",
	"update_plan":  "todowrite",
	"updatePlan":   "todowrite",
	"read_plan":    "todoread",
	"readPlan":     "todoread",
	"search_files": "grep",
	"searchFiles":  "grep",
	"list_files":   "glob",
	"listFiles":    "glob",
	"read_file":    "read",
	"readFile":     "read",
	"write_file":   "write",
	"writeFile":    "write",
	"execute_bash": "bash",
	"executeBash":  "bash",
	"exec_bash":    "bash",
	"execBash":     "bash",
}
// ToolCorrectionStats records tool-correction counters (exported for JSON
// serialization).
type ToolCorrectionStats struct {
	// TotalCorrected is the total number of tool-name corrections applied.
	TotalCorrected int `json:"total_corrected"`
	// CorrectionsByTool counts corrections keyed by "from->to" name pairs.
	CorrectionsByTool map[string]int `json:"corrections_by_tool"`
}
// CodexToolCorrector applies automatic corrections to Codex tool calls.
// Statistics access is guarded by mu, so the corrector is safe for concurrent use.
type CodexToolCorrector struct {
	stats ToolCorrectionStats
	mu    sync.RWMutex
}
// NewCodexToolCorrector 创建新的工具修正器
func
NewCodexToolCorrector
()
*
CodexToolCorrector
{
return
&
CodexToolCorrector
{
stats
:
ToolCorrectionStats
{
CorrectionsByTool
:
make
(
map
[
string
]
int
),
},
}
}
// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用
// 返回修正后的数据和是否进行了修正
func
(
c
*
CodexToolCorrector
)
CorrectToolCallsInSSEData
(
data
string
)
(
string
,
bool
)
{
if
data
==
""
||
data
==
"
\n
"
{
return
data
,
false
}
// 尝试解析 JSON
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
payload
);
err
!=
nil
{
// 不是有效的 JSON,直接返回原数据
return
data
,
false
}
corrected
:=
false
// 处理 tool_calls 数组
if
toolCalls
,
ok
:=
payload
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
// 处理 function_call 对象
if
functionCall
,
ok
:=
payload
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
// 处理 delta.tool_calls
if
delta
,
ok
:=
payload
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
toolCalls
,
ok
:=
delta
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
if
functionCall
,
ok
:=
delta
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
}
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
if
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
);
ok
{
for
_
,
choice
:=
range
choices
{
if
choiceMap
,
ok
:=
choice
.
(
map
[
string
]
any
);
ok
{
// 处理 message 中的工具调用
if
message
,
ok
:=
choiceMap
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
toolCalls
,
ok
:=
message
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
if
functionCall
,
ok
:=
message
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
}
// 处理 delta 中的工具调用
if
delta
,
ok
:=
choiceMap
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
toolCalls
,
ok
:=
delta
[
"tool_calls"
]
.
([]
any
);
ok
{
if
c
.
correctToolCallsArray
(
toolCalls
)
{
corrected
=
true
}
}
if
functionCall
,
ok
:=
delta
[
"function_call"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
functionCall
)
{
corrected
=
true
}
}
}
}
}
}
if
!
corrected
{
return
data
,
false
}
// 序列化回 JSON
correctedBytes
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
log
.
Printf
(
"[CodexToolCorrector] Failed to marshal corrected data: %v"
,
err
)
return
data
,
false
}
return
string
(
correctedBytes
),
true
}
// correctToolCallsArray 修正工具调用数组中的工具名称
func
(
c
*
CodexToolCorrector
)
correctToolCallsArray
(
toolCalls
[]
any
)
bool
{
corrected
:=
false
for
_
,
toolCall
:=
range
toolCalls
{
if
toolCallMap
,
ok
:=
toolCall
.
(
map
[
string
]
any
);
ok
{
if
function
,
ok
:=
toolCallMap
[
"function"
]
.
(
map
[
string
]
any
);
ok
{
if
c
.
correctFunctionCall
(
function
)
{
corrected
=
true
}
}
}
}
return
corrected
}
// correctFunctionCall 修正单个函数调用的工具名称和参数
func
(
c
*
CodexToolCorrector
)
correctFunctionCall
(
functionCall
map
[
string
]
any
)
bool
{
name
,
ok
:=
functionCall
[
"name"
]
.
(
string
)
if
!
ok
||
name
==
""
{
return
false
}
corrected
:=
false
// 查找并修正工具名称
if
correctName
,
found
:=
codexToolNameMapping
[
name
];
found
{
functionCall
[
"name"
]
=
correctName
c
.
recordCorrection
(
name
,
correctName
)
corrected
=
true
name
=
correctName
// 使用修正后的名称进行参数修正
}
// 修正工具参数(基于工具名称)
if
c
.
correctToolParameters
(
name
,
functionCall
)
{
corrected
=
true
}
return
corrected
}
// correctToolParameters rewrites tool arguments to match the OpenCode spec.
// Returns true when any argument was added, renamed, or removed.
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
	arguments, ok := functionCall["arguments"]
	if !ok {
		return false
	}
	// arguments may arrive either as a JSON string or as an already-decoded map.
	var argsMap map[string]any
	switch v := arguments.(type) {
	case string:
		// Decode the JSON string form.
		if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
			return false
		}
	case map[string]any:
		argsMap = v
	default:
		return false
	}
	corrected := false
	// Apply tool-specific parameter rules keyed by tool name.
	switch toolName {
	case "bash":
		// Remove the workdir parameter (not supported by OpenCode).
		if _, exists := argsMap["workdir"]; exists {
			delete(argsMap, "workdir")
			corrected = true
			log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
		}
		if _, exists := argsMap["work_dir"]; exists {
			delete(argsMap, "work_dir")
			corrected = true
			log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
		}
	case "edit":
		// OpenCode edit uses old_string/new_string; Codex may use other names.
		// Additional parameter-name mappings can be added here.
		if _, exists := argsMap["file_path"]; !exists {
			if path, exists := argsMap["path"]; exists {
				argsMap["file_path"] = path
				delete(argsMap, "path")
				corrected = true
				log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
			}
		}
	}
	// If parameters changed, re-serialize in the same representation they arrived in.
	if corrected {
		if _, wasString := arguments.(string); wasString {
			// Originally a string: marshal the map back to a JSON string.
			if newArgsJSON, err := json.Marshal(argsMap); err == nil {
				functionCall["arguments"] = string(newArgsJSON)
			}
		} else {
			// Originally a map: assign directly (argsMap aliases the original map).
			functionCall["arguments"] = argsMap
		}
	}
	return corrected
}
// recordCorrection 记录一次工具名称修正
func
(
c
*
CodexToolCorrector
)
recordCorrection
(
from
,
to
string
)
{
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
c
.
stats
.
TotalCorrected
++
key
:=
fmt
.
Sprintf
(
"%s->%s"
,
from
,
to
)
c
.
stats
.
CorrectionsByTool
[
key
]
++
log
.
Printf
(
"[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)"
,
from
,
to
,
c
.
stats
.
TotalCorrected
)
}
// GetStats 获取工具修正统计信息
func
(
c
*
CodexToolCorrector
)
GetStats
()
ToolCorrectionStats
{
c
.
mu
.
RLock
()
defer
c
.
mu
.
RUnlock
()
// 返回副本以避免并发问题
statsCopy
:=
ToolCorrectionStats
{
TotalCorrected
:
c
.
stats
.
TotalCorrected
,
CorrectionsByTool
:
make
(
map
[
string
]
int
,
len
(
c
.
stats
.
CorrectionsByTool
)),
}
for
k
,
v
:=
range
c
.
stats
.
CorrectionsByTool
{
statsCopy
.
CorrectionsByTool
[
k
]
=
v
}
return
statsCopy
}
// ResetStats 重置统计信息
func
(
c
*
CodexToolCorrector
)
ResetStats
()
{
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
c
.
stats
.
TotalCorrected
=
0
c
.
stats
.
CorrectionsByTool
=
make
(
map
[
string
]
int
)
}
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
func
CorrectToolName
(
name
string
)
(
string
,
bool
)
{
if
correctName
,
found
:=
codexToolNameMapping
[
name
];
found
{
return
correctName
,
true
}
return
name
,
false
}
// GetToolNameMapping 获取工具名称映射表
func
GetToolNameMapping
()
map
[
string
]
string
{
// 返回副本以避免外部修改
mapping
:=
make
(
map
[
string
]
string
,
len
(
codexToolNameMapping
))
for
k
,
v
:=
range
codexToolNameMapping
{
mapping
[
k
]
=
v
}
return
mapping
}
backend/internal/service/openai_tool_corrector_test.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"encoding/json"
"testing"
)
func
TestCorrectToolCallsInSSEData
(
t
*
testing
.
T
)
{
corrector
:=
NewCodexToolCorrector
()
tests
:=
[]
struct
{
name
string
input
string
expectCorrected
bool
checkFunc
func
(
t
*
testing
.
T
,
result
string
)
}{
{
name
:
"empty string"
,
input
:
""
,
expectCorrected
:
false
,
},
{
name
:
"newline only"
,
input
:
"
\n
"
,
expectCorrected
:
false
,
},
{
name
:
"invalid json"
,
input
:
"not a json"
,
expectCorrected
:
false
,
},
{
name
:
"correct apply_patch in tool_calls"
,
input
:
`{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`
,
expectCorrected
:
true
,
checkFunc
:
func
(
t
*
testing
.
T
,
result
string
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
result
),
&
payload
);
err
!=
nil
{
t
.
Fatalf
(
"Failed to parse result: %v"
,
err
)
}
toolCalls
,
ok
:=
payload
[
"tool_calls"
]
.
([]
any
)
if
!
ok
||
len
(
toolCalls
)
==
0
{
t
.
Fatal
(
"No tool_calls found in result"
)
}
toolCall
,
ok
:=
toolCalls
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid tool_call format"
)
}
functionCall
,
ok
:=
toolCall
[
"function"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid function format"
)
}
if
functionCall
[
"name"
]
!=
"edit"
{
t
.
Errorf
(
"Expected tool name 'edit', got '%v'"
,
functionCall
[
"name"
])
}
},
},
{
name
:
"correct update_plan in function_call"
,
input
:
`{"function_call":{"name":"update_plan","arguments":"{}"}}`
,
expectCorrected
:
true
,
checkFunc
:
func
(
t
*
testing
.
T
,
result
string
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
result
),
&
payload
);
err
!=
nil
{
t
.
Fatalf
(
"Failed to parse result: %v"
,
err
)
}
functionCall
,
ok
:=
payload
[
"function_call"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid function_call format"
)
}
if
functionCall
[
"name"
]
!=
"todowrite"
{
t
.
Errorf
(
"Expected tool name 'todowrite', got '%v'"
,
functionCall
[
"name"
])
}
},
},
{
name
:
"correct search_files in delta.tool_calls"
,
input
:
`{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`
,
expectCorrected
:
true
,
checkFunc
:
func
(
t
*
testing
.
T
,
result
string
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
result
),
&
payload
);
err
!=
nil
{
t
.
Fatalf
(
"Failed to parse result: %v"
,
err
)
}
delta
,
ok
:=
payload
[
"delta"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid delta format"
)
}
toolCalls
,
ok
:=
delta
[
"tool_calls"
]
.
([]
any
)
if
!
ok
||
len
(
toolCalls
)
==
0
{
t
.
Fatal
(
"No tool_calls found in delta"
)
}
toolCall
,
ok
:=
toolCalls
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid tool_call format"
)
}
functionCall
,
ok
:=
toolCall
[
"function"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid function format"
)
}
if
functionCall
[
"name"
]
!=
"grep"
{
t
.
Errorf
(
"Expected tool name 'grep', got '%v'"
,
functionCall
[
"name"
])
}
},
},
{
name
:
"correct list_files in choices.message.tool_calls"
,
input
:
`{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`
,
expectCorrected
:
true
,
checkFunc
:
func
(
t
*
testing
.
T
,
result
string
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
result
),
&
payload
);
err
!=
nil
{
t
.
Fatalf
(
"Failed to parse result: %v"
,
err
)
}
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
)
if
!
ok
||
len
(
choices
)
==
0
{
t
.
Fatal
(
"No choices found in result"
)
}
choice
,
ok
:=
choices
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid choice format"
)
}
message
,
ok
:=
choice
[
"message"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid message format"
)
}
toolCalls
,
ok
:=
message
[
"tool_calls"
]
.
([]
any
)
if
!
ok
||
len
(
toolCalls
)
==
0
{
t
.
Fatal
(
"No tool_calls found in message"
)
}
toolCall
,
ok
:=
toolCalls
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid tool_call format"
)
}
functionCall
,
ok
:=
toolCall
[
"function"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid function format"
)
}
if
functionCall
[
"name"
]
!=
"glob"
{
t
.
Errorf
(
"Expected tool name 'glob', got '%v'"
,
functionCall
[
"name"
])
}
},
},
{
name
:
"no correction needed"
,
input
:
`{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`
,
expectCorrected
:
false
,
},
{
name
:
"correct multiple tool calls"
,
input
:
`{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`
,
expectCorrected
:
true
,
checkFunc
:
func
(
t
*
testing
.
T
,
result
string
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
result
),
&
payload
);
err
!=
nil
{
t
.
Fatalf
(
"Failed to parse result: %v"
,
err
)
}
toolCalls
,
ok
:=
payload
[
"tool_calls"
]
.
([]
any
)
if
!
ok
||
len
(
toolCalls
)
<
2
{
t
.
Fatal
(
"Expected at least 2 tool_calls"
)
}
toolCall1
,
ok
:=
toolCalls
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid first tool_call format"
)
}
func1
,
ok
:=
toolCall1
[
"function"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid first function format"
)
}
if
func1
[
"name"
]
!=
"edit"
{
t
.
Errorf
(
"Expected first tool name 'edit', got '%v'"
,
func1
[
"name"
])
}
toolCall2
,
ok
:=
toolCalls
[
1
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid second tool_call format"
)
}
func2
,
ok
:=
toolCall2
[
"function"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid second function format"
)
}
if
func2
[
"name"
]
!=
"read"
{
t
.
Errorf
(
"Expected second tool name 'read', got '%v'"
,
func2
[
"name"
])
}
},
},
{
name
:
"camelCase format - applyPatch"
,
input
:
`{"tool_calls":[{"function":{"name":"applyPatch"}}]}`
,
expectCorrected
:
true
,
checkFunc
:
func
(
t
*
testing
.
T
,
result
string
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
result
),
&
payload
);
err
!=
nil
{
t
.
Fatalf
(
"Failed to parse result: %v"
,
err
)
}
toolCalls
,
ok
:=
payload
[
"tool_calls"
]
.
([]
any
)
if
!
ok
||
len
(
toolCalls
)
==
0
{
t
.
Fatal
(
"No tool_calls found in result"
)
}
toolCall
,
ok
:=
toolCalls
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid tool_call format"
)
}
functionCall
,
ok
:=
toolCall
[
"function"
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatal
(
"Invalid function format"
)
}
if
functionCall
[
"name"
]
!=
"edit"
{
t
.
Errorf
(
"Expected tool name 'edit', got '%v'"
,
functionCall
[
"name"
])
}
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
,
corrected
:=
corrector
.
CorrectToolCallsInSSEData
(
tt
.
input
)
if
corrected
!=
tt
.
expectCorrected
{
t
.
Errorf
(
"Expected corrected=%v, got %v"
,
tt
.
expectCorrected
,
corrected
)
}
if
!
corrected
&&
result
!=
tt
.
input
{
t
.
Errorf
(
"Expected unchanged result when not corrected"
)
}
if
tt
.
checkFunc
!=
nil
{
tt
.
checkFunc
(
t
,
result
)
}
})
}
}
// TestCorrectToolName verifies the direct name-mapping helper for every known
// alias (snake_case and camelCase) plus pass-through of unmapped names.
func TestCorrectToolName(t *testing.T) {
	tests := []struct {
		input     string
		expected  string
		corrected bool
	}{
		{"apply_patch", "edit", true},
		{"applyPatch", "edit", true},
		{"update_plan", "todowrite", true},
		{"updatePlan", "todowrite", true},
		{"read_plan", "todoread", true},
		{"readPlan", "todoread", true},
		{"search_files", "grep", true},
		{"searchFiles", "grep", true},
		{"list_files", "glob", true},
		{"listFiles", "glob", true},
		{"read_file", "read", true},
		{"readFile", "read", true},
		{"write_file", "write", true},
		{"writeFile", "write", true},
		{"execute_bash", "bash", true},
		{"executeBash", "bash", true},
		{"exec_bash", "bash", true},
		{"execBash", "bash", true},
		// Unknown and already-canonical names must pass through unchanged.
		{"unknown_tool", "unknown_tool", false},
		{"read", "read", false},
		{"edit", "edit", false},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			result, corrected := CorrectToolName(tt.input)
			if corrected != tt.corrected {
				t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
			}
			if result != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, result)
			}
		})
	}
}
// TestGetToolNameMapping spot-checks a sample of expected entries and verifies
// the returned map is a defensive copy.
func TestGetToolNameMapping(t *testing.T) {
	mapping := GetToolNameMapping()
	expectedMappings := map[string]string{
		"apply_patch":  "edit",
		"update_plan":  "todowrite",
		"read_plan":    "todoread",
		"search_files": "grep",
		"list_files":   "glob",
	}
	for from, to := range expectedMappings {
		if mapping[from] != to {
			t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
		}
	}
	// Mutating the returned copy must not leak back into the canonical table.
	mapping["test_tool"] = "test_value"
	newMapping := GetToolNameMapping()
	if _, exists := newMapping["test_tool"]; exists {
		t.Error("Modifications to returned mapping should not affect original")
	}
}
// TestCorrectorStats verifies stats start at zero, count per from->to pair,
// and reset cleanly.
func TestCorrectorStats(t *testing.T) {
	corrector := NewCodexToolCorrector()
	stats := corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
	}
	// Two apply_patch corrections and one update_plan correction.
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
	stats = corrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
	}
	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}
	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}
	// Reset must clear both the total and the per-tool map.
	corrector.ResetStats()
	stats = corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
	}
}
// TestComplexSSEData runs the corrector over a realistic streaming chunk and
// verifies the nested choices[0].delta.tool_calls[0] function name is rewritten.
func TestComplexSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()
	input := `{
  "id": "chatcmpl-123",
  "object": "chat.completion.chunk",
  "created": 1234567890,
  "model": "gpt-5.1-codex",
  "choices": [
    {
      "index": 0,
      "delta": {
        "tool_calls": [
          {
            "index": 0,
            "function": {
              "name": "apply_patch",
              "arguments": "{\"file\":\"test.go\"}"
            }
          }
        ]
      },
      "finish_reason": null
    }
  ]
}`
	result, corrected := corrector.CorrectToolCallsInSSEData(input)
	if !corrected {
		t.Error("Expected data to be corrected")
	}
	// Navigate choices[0].delta.tool_calls[0].function and check the new name.
	var payload map[string]any
	if err := json.Unmarshal([]byte(result), &payload); err != nil {
		t.Fatalf("Failed to parse result: %v", err)
	}
	choices, ok := payload["choices"].([]any)
	if !ok || len(choices) == 0 {
		t.Fatal("No choices found in result")
	}
	choice, ok := choices[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid choice format")
	}
	delta, ok := choice["delta"].(map[string]any)
	if !ok {
		t.Fatal("Invalid delta format")
	}
	toolCalls, ok := delta["tool_calls"].([]any)
	if !ok || len(toolCalls) == 0 {
		t.Fatal("No tool_calls found in delta")
	}
	toolCall, ok := toolCalls[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid tool_call format")
	}
	function, ok := toolCall["function"].(map[string]any)
	if !ok {
		t.Fatal("Invalid function format")
	}
	if function["name"] != "edit" {
		t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
	}
}
// TestCorrectToolParameters verifies tool-argument corrections: dropping
// unsupported bash parameters and renaming edit parameters.
func TestCorrectToolParameters(t *testing.T) {
	corrector := NewCodexToolCorrector()
	tests := []struct {
		name     string
		input    string
		expected map[string]bool // key: parameter expected in output; value: true means it should exist
	}{
		{
			name: "remove workdir from bash tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "bash",
						"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"command": true,
				"workdir": false,
			},
		},
		{
			name: "rename path to file_path in edit tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "apply_patch",
						"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"file_path":  true,
				"path":       false,
				"old_string": true,
				"new_string": true,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
			if !changed {
				t.Error("expected data to be corrected")
			}
			// Parse the corrected payload.
			var result map[string]any
			if err := json.Unmarshal([]byte(corrected), &result); err != nil {
				t.Fatalf("failed to parse corrected data: %v", err)
			}
			// Drill into the first tool call.
			toolCalls, ok := result["tool_calls"].([]any)
			if !ok || len(toolCalls) == 0 {
				t.Fatal("no tool_calls found in corrected data")
			}
			toolCall, ok := toolCalls[0].(map[string]any)
			if !ok {
				t.Fatal("invalid tool_call structure")
			}
			function, ok := toolCall["function"].(map[string]any)
			if !ok {
				t.Fatal("no function found in tool_call")
			}
			// Arguments should remain in string form after re-serialization.
			argumentsStr, ok := function["arguments"].(string)
			if !ok {
				t.Fatal("arguments is not a string")
			}
			var args map[string]any
			if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
				t.Fatalf("failed to parse arguments: %v", err)
			}
			// Verify presence/absence of each expected parameter.
			for param, shouldExist := range tt.expected {
				_, exists := args[param]
				if shouldExist && !exists {
					t.Errorf("expected parameter %q to exist, but it doesn't", param)
				}
				if !shouldExist && exists {
					t.Errorf("expected parameter %q to not exist, but it does", param)
				}
			}
		})
	}
}
backend/internal/service/ops_account_availability.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"errors"
"time"
)
// GetAccountAvailabilityStats returns current account availability stats.
//
// Query-level filtering is intentionally limited to platform/group to match the dashboard scope.
// It returns per-platform, per-group, and per-account availability maps plus
// the collection timestamp.
func (s *OpsService) GetAccountAvailabilityStats(
	ctx context.Context,
	platformFilter string,
	groupIDFilter *int64,
) (
	map[string]*PlatformAvailability,
	map[int64]*GroupAvailability,
	map[int64]*AccountAvailability,
	*time.Time,
	error,
) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, nil, nil, nil, err
	}
	accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
	if err != nil {
		return nil, nil, nil, nil, err
	}
	// Optional group filter: keep only accounts that belong to the group.
	if groupIDFilter != nil && *groupIDFilter > 0 {
		filtered := make([]Account, 0, len(accounts))
		for _, acc := range accounts {
			for _, grp := range acc.Groups {
				if grp != nil && grp.ID == *groupIDFilter {
					filtered = append(filtered, acc)
					break
				}
			}
		}
		accounts = filtered
	}
	now := time.Now()
	collectedAt := now
	platform := make(map[string]*PlatformAvailability)
	group := make(map[int64]*GroupAvailability)
	account := make(map[int64]*AccountAvailability)
	for _, acc := range accounts {
		if acc.ID <= 0 {
			continue
		}
		// Derive transient unavailability flags from their expiry timestamps.
		isTempUnsched := false
		if acc.TempUnschedulableUntil != nil && now.Before(*acc.TempUnschedulableUntil) {
			isTempUnsched = true
		}
		isRateLimited := acc.RateLimitResetAt != nil && now.Before(*acc.RateLimitResetAt)
		isOverloaded := acc.OverloadUntil != nil && now.Before(*acc.OverloadUntil)
		hasError := acc.Status == StatusError
		// Normalize exclusive status flags so the UI doesn't show conflicting badges.
		if hasError {
			isRateLimited = false
			isOverloaded = false
		}
		// Available means: active, schedulable, and in no penalty window.
		isAvailable := acc.Status == StatusActive &&
			acc.Schedulable &&
			!isRateLimited &&
			!isOverloaded &&
			!isTempUnsched
		// Per-platform rollup (skipped for accounts without a platform).
		if acc.Platform != "" {
			if _, ok := platform[acc.Platform]; !ok {
				platform[acc.Platform] = &PlatformAvailability{
					Platform: acc.Platform,
				}
			}
			p := platform[acc.Platform]
			p.TotalAccounts++
			if isAvailable {
				p.AvailableCount++
			}
			if isRateLimited {
				p.RateLimitCount++
			}
			if hasError {
				p.ErrorCount++
			}
		}
		// Per-group rollup; an account may belong to several groups and is
		// counted once in each.
		for _, grp := range acc.Groups {
			if grp == nil || grp.ID <= 0 {
				continue
			}
			if _, ok := group[grp.ID]; !ok {
				group[grp.ID] = &GroupAvailability{
					GroupID:   grp.ID,
					GroupName: grp.Name,
					Platform:  grp.Platform,
				}
			}
			g := group[grp.ID]
			g.TotalAccounts++
			if isAvailable {
				g.AvailableCount++
			}
			if isRateLimited {
				g.RateLimitCount++
			}
			if hasError {
				g.ErrorCount++
			}
		}
		// Use the first group (if any) as the display group on the account row.
		displayGroupID := int64(0)
		displayGroupName := ""
		if len(acc.Groups) > 0 && acc.Groups[0] != nil {
			displayGroupID = acc.Groups[0].ID
			displayGroupName = acc.Groups[0].Name
		}
		item := &AccountAvailability{
			AccountID:     acc.ID,
			AccountName:   acc.Name,
			Platform:      acc.Platform,
			GroupID:       displayGroupID,
			GroupName:     displayGroupName,
			Status:        acc.Status,
			IsAvailable:   isAvailable,
			IsRateLimited: isRateLimited,
			IsOverloaded:  isOverloaded,
			HasError:      hasError,
			ErrorMessage:  acc.ErrorMessage,
		}
		// Attach countdowns only while the corresponding penalty window is active.
		if isRateLimited && acc.RateLimitResetAt != nil {
			item.RateLimitResetAt = acc.RateLimitResetAt
			remainingSec := int64(time.Until(*acc.RateLimitResetAt).Seconds())
			if remainingSec > 0 {
				item.RateLimitRemainingSec = &remainingSec
			}
		}
		if isOverloaded && acc.OverloadUntil != nil {
			item.OverloadUntil = acc.OverloadUntil
			remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
			if remainingSec > 0 {
				item.OverloadRemainingSec = &remainingSec
			}
		}
		if isTempUnsched && acc.TempUnschedulableUntil != nil {
			item.TempUnschedulableUntil = acc.TempUnschedulableUntil
		}
		account[acc.ID] = item
	}
	return platform, group, account, &collectedAt, nil
}
// OpsAccountAvailability bundles one group's availability summary with
// per-account details and the time the data was collected.
type OpsAccountAvailability struct {
	// Group is nil when no group filter was applied or the group was not found.
	Group *GroupAvailability
	// Accounts maps account ID to its availability detail.
	Accounts map[int64]*AccountAvailability
	// CollectedAt is when the underlying stats were gathered.
	CollectedAt *time.Time
}
// GetAccountAvailability returns availability for an optional platform/group
// filter. When the injected getAccountAvailability hook is set it takes
// precedence over the built-in computation.
func (s *OpsService) GetAccountAvailability(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error) {
	if s == nil {
		return nil, errors.New("ops service is nil")
	}
	// Delegate to the injected implementation when configured
	// (presumably a test/override seam — confirm against callers).
	if s.getAccountAvailability != nil {
		return s.getAccountAvailability(ctx, platformFilter, groupIDFilter)
	}
	// Platform-level stats are discarded here; only group/account views are returned.
	_, groupStats, accountStats, collectedAt, err := s.GetAccountAvailabilityStats(ctx, platformFilter, groupIDFilter)
	if err != nil {
		return nil, err
	}
	var group *GroupAvailability
	if groupIDFilter != nil && *groupIDFilter > 0 {
		group = groupStats[*groupIDFilter]
	}
	// Normalize to a non-nil map so callers can range without nil checks.
	if accountStats == nil {
		accountStats = map[int64]*AccountAvailability{}
	}
	return &OpsAccountAvailability{
		Group:       group,
		Accounts:    accountStats,
		CollectedAt: collectedAt,
	}, nil
}
backend/internal/service/ops_advisory_lock.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"database/sql"
"hash/fnv"
"time"
)
// hashAdvisoryLockID folds a string key into a 64-bit advisory-lock ID using
// the FNV-1a hash.
func hashAdvisoryLockID(key string) int64 {
	hasher := fnv.New64a()
	_, _ = hasher.Write([]byte(key)) // fnv hashes never return a write error
	return int64(hasher.Sum64())
}
// tryAcquireDBAdvisoryLock attempts to take a PostgreSQL advisory lock on a
// dedicated connection. On success it returns a release func (which unlocks
// and closes the connection) and true; on any failure it returns (nil, false).
func tryAcquireDBAdvisoryLock(ctx context.Context, db *sql.DB, lockID int64) (func(), bool) {
	if db == nil {
		return nil, false
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Pin a dedicated connection: advisory locks are session-scoped, so the
	// unlock must run on the same connection that acquired the lock.
	conn, err := db.Conn(ctx)
	if err != nil {
		return nil, false
	}
	acquired := false
	if err := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired); err != nil {
		_ = conn.Close() // best-effort cleanup; lock was not taken
		return nil, false
	}
	if !acquired {
		_ = conn.Close()
		return nil, false
	}
	release := func() {
		// Use a short independent timeout so release works even if the caller's
		// context is already canceled.
		unlockCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		// Best-effort unlock; closing the connection also drops the session lock.
		_, _ = conn.ExecContext(unlockCtx, "SELECT pg_advisory_unlock($1)", lockID)
		_ = conn.Close()
	}
	return release, true
}
backend/internal/service/ops_aggregation_service.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
const (
	// Job identifiers for the two pre-aggregation tasks.
	opsAggHourlyJobName = "ops_preaggregation_hourly"
	opsAggDailyJobName  = "ops_preaggregation_daily"
	// How often each aggregation loop runs.
	opsAggHourlyInterval = 10 * time.Minute
	opsAggDailyInterval  = 1 * time.Hour
	// Keep in sync with ops retention target (vNext default 30d).
	opsAggBackfillWindow = 30 * 24 * time.Hour
	// Recompute overlap to absorb late-arriving rows near boundaries.
	opsAggHourlyOverlap = 2 * time.Hour
	opsAggDailyOverlap  = 48 * time.Hour
	// Maximum time span processed per pass.
	opsAggHourlyChunk = 24 * time.Hour
	opsAggDailyChunk  = 7 * 24 * time.Hour
	// Delay around boundaries (e.g. 10:00..10:05) to avoid aggregating buckets
	// that may still receive late inserts.
	opsAggSafeDelay       = 5 * time.Minute
	opsAggMaxQueryTimeout = 3 * time.Second
	// Overall per-run timeouts for each aggregation job.
	opsAggHourlyTimeout = 5 * time.Minute
	opsAggDailyTimeout  = 2 * time.Minute
	// Redis leader-lock keys and TTLs used to coordinate multi-replica runs.
	opsAggHourlyLeaderLockKey = "ops:aggregation:hourly:leader"
	opsAggDailyLeaderLockKey  = "ops:aggregation:daily:leader"
	opsAggHourlyLeaderLockTTL = 15 * time.Minute
	opsAggDailyLeaderLockTTL  = 10 * time.Minute
)
// OpsAggregationService periodically backfills ops_metrics_hourly / ops_metrics_daily
// for stable long-window dashboard queries.
//
// It is safe to run in multi-replica deployments when Redis is available (leader lock).
type OpsAggregationService struct {
	opsRepo     OpsRepository     // destination for aggregated metrics and heartbeats
	settingRepo SettingRepository // runtime on/off switch (may be nil → always enabled)
	cfg         *config.Config    // static config gate (Ops.Enabled / Ops.Aggregation.Enabled)
	db          *sql.DB           // fallback advisory-lock source when Redis is unavailable
	redisClient *redis.Client     // preferred leader-lock backend (may be nil)

	instanceID string        // unique per-process token stored as the lock value
	stopCh     chan struct{} // closed by Stop to terminate both loops
	startOnce  sync.Once     // guards loop startup
	stopOnce   sync.Once     // guards stopCh close

	hourlyMu sync.Mutex // serializes aggregateHourly within this process
	dailyMu  sync.Mutex // serializes aggregateDaily within this process

	skipLogMu sync.Mutex // guards skipLogAt
	skipLogAt time.Time  // last time a "lock held elsewhere" message was logged (rate limit)
}
func
NewOpsAggregationService
(
opsRepo
OpsRepository
,
settingRepo
SettingRepository
,
db
*
sql
.
DB
,
redisClient
*
redis
.
Client
,
cfg
*
config
.
Config
,
)
*
OpsAggregationService
{
return
&
OpsAggregationService
{
opsRepo
:
opsRepo
,
settingRepo
:
settingRepo
,
cfg
:
cfg
,
db
:
db
,
redisClient
:
redisClient
,
instanceID
:
uuid
.
NewString
(),
}
}
func
(
s
*
OpsAggregationService
)
Start
()
{
if
s
==
nil
{
return
}
s
.
startOnce
.
Do
(
func
()
{
if
s
.
stopCh
==
nil
{
s
.
stopCh
=
make
(
chan
struct
{})
}
go
s
.
hourlyLoop
()
go
s
.
dailyLoop
()
})
}
func
(
s
*
OpsAggregationService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
if
s
.
stopCh
!=
nil
{
close
(
s
.
stopCh
)
}
})
}
func
(
s
*
OpsAggregationService
)
hourlyLoop
()
{
// First run immediately.
s
.
aggregateHourly
()
ticker
:=
time
.
NewTicker
(
opsAggHourlyInterval
)
defer
ticker
.
Stop
()
for
{
select
{
case
<-
ticker
.
C
:
s
.
aggregateHourly
()
case
<-
s
.
stopCh
:
return
}
}
}
func
(
s
*
OpsAggregationService
)
dailyLoop
()
{
// First run immediately.
s
.
aggregateDaily
()
ticker
:=
time
.
NewTicker
(
opsAggDailyInterval
)
defer
ticker
.
Stop
()
for
{
select
{
case
<-
ticker
.
C
:
s
.
aggregateDaily
()
case
<-
s
.
stopCh
:
return
}
}
}
// aggregateHourly backfills the hourly rollup table for the last
// opsAggBackfillWindow of stable full hours, resuming from the latest
// existing bucket (minus an overlap to absorb late rows). Progress and
// failures are recorded as a job heartbeat. Only one instance runs at a
// time: a leader lock guards across replicas, hourlyMu within the process.
func (s *OpsAggregationService) aggregateHourly() {
	if s == nil || s.opsRepo == nil {
		return
	}
	// Static config gates: both the ops subsystem and aggregation must be on.
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAggHourlyTimeout)
	defer cancel()
	// Runtime toggle (settings-backed) can also pause aggregation.
	if !s.isMonitoringEnabled(ctx) {
		return
	}
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggHourlyLeaderLockKey, opsAggHourlyLeaderLockTTL, "[OpsAggregation][hourly]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	s.hourlyMu.Lock()
	defer s.hourlyMu.Unlock()
	startedAt := time.Now().UTC()
	runAt := startedAt
	// Aggregate stable full hours only.
	end := utcFloorToHour(time.Now().UTC().Add(-opsAggSafeDelay))
	start := end.Add(-opsAggBackfillWindow)
	// Resume from the latest bucket with overlap.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax)
		cancelMax()
		if err != nil {
			// Non-fatal: fall back to recomputing the whole backfill window.
			log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err)
		} else if ok {
			candidate := latest.Add(-opsAggHourlyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}
	start = utcFloorToHour(start)
	if !start.Before(end) {
		// Nothing new to aggregate.
		return
	}
	// Upsert in bounded chunks; stop at the first failure so the next run
	// retries from the same point (the resume logic above handles that).
	var aggErr error
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggHourlyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end)
		if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
			break
		}
	}
	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs
	if aggErr != nil {
		// Record the failure heartbeat (best effort; error deliberately ignored).
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggHourlyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}
	// Record the success heartbeat with the window that was covered.
	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggHourlyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
		LastResult:     &result,
	})
}
// aggregateDaily mirrors aggregateHourly for the daily rollup table: it
// backfills whole UTC days up to (but not including) today, resuming from
// the latest existing daily bucket minus a 48h overlap. Heartbeats record
// success/failure; a leader lock plus dailyMu prevent concurrent runs.
func (s *OpsAggregationService) aggregateDaily() {
	if s == nil || s.opsRepo == nil {
		return
	}
	// Static config gates: both the ops subsystem and aggregation must be on.
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAggDailyTimeout)
	defer cancel()
	if !s.isMonitoringEnabled(ctx) {
		return
	}
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggDailyLeaderLockKey, opsAggDailyLeaderLockTTL, "[OpsAggregation][daily]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	s.dailyMu.Lock()
	defer s.dailyMu.Unlock()
	startedAt := time.Now().UTC()
	runAt := startedAt
	// Only completed days: the current (partial) day is excluded by flooring.
	end := utcFloorToDay(time.Now().UTC())
	start := end.Add(-opsAggBackfillWindow)
	// Resume from the latest daily bucket with overlap.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax)
		cancelMax()
		if err != nil {
			// Non-fatal: fall back to recomputing the whole backfill window.
			log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err)
		} else if ok {
			candidate := latest.Add(-opsAggDailyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}
	start = utcFloorToDay(start)
	if !start.Before(end) {
		// Nothing new to aggregate.
		return
	}
	// Upsert in week-sized chunks; abort on first failure so the next run
	// retries from the same point.
	var aggErr error
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggDailyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end)
		if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
			break
		}
	}
	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs
	if aggErr != nil {
		// Record the failure heartbeat (best effort; error deliberately ignored).
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggDailyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}
	// Record the success heartbeat with the window that was covered.
	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggDailyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
		LastResult:     &result,
	})
}
// isMonitoringEnabled reports whether ops monitoring is currently enabled.
// Order of precedence: static config (fail closed), then the runtime setting
// (fail OPEN — a missing setting or a read error both keep monitoring on).
func (s *OpsAggregationService) isMonitoringEnabled(ctx context.Context) bool {
	if s == nil {
		return false
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return false
	}
	if s.settingRepo == nil {
		// No settings store configured: default to enabled.
		return true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
	if err != nil {
		// NOTE(review): both branches fail open, so the errors.Is check is
		// currently redundant — kept to make the "not found is fine" intent
		// explicit; collapse once the transient-error policy is confirmed.
		if errors.Is(err, ErrSettingNotFound) {
			return true
		}
		return true
	}
	// Only an explicit "off" value disables monitoring; anything else enables.
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "false", "0", "off", "disabled":
		return false
	default:
		return true
	}
}
// opsAggReleaseScript atomically deletes the leader-lock key only when this
// instance still owns it (compare-and-delete), so an expired lock that was
// re-acquired by another instance is never released by mistake.
var opsAggReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("DEL", KEYS[1])
end
return 0
`)
// tryAcquireLeaderLock elects this process as the aggregation leader for the
// given key. Returns (release, true) on success; release may be nil.
// Strategy: Redis SetNX first (multi-instance safe), falling back to a
// PostgreSQL advisory lock when Redis errors, so a flaky Redis does not
// cause every replica to stampede the database.
func (s *OpsAggregationService) tryAcquireLeaderLock(ctx context.Context, key string, ttl time.Duration, logPrefix string) (func(), bool) {
	if s == nil {
		return nil, false
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Prefer Redis leader lock when available (multi-instance), but avoid stampeding
	// the DB when Redis is flaky by falling back to a DB advisory lock.
	if s.redisClient != nil {
		ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
		if err == nil {
			if !ok {
				// Another instance holds the lock; log at most once per minute.
				s.maybeLogSkip(logPrefix)
				return nil, false
			}
			release := func() {
				// Fresh short-lived context: the caller's ctx may be done by
				// the time the deferred release runs.
				ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
				defer cancel()
				_, _ = opsAggReleaseScript.Run(ctx2, s.redisClient, []string{key}, s.instanceID).Result()
			}
			return release, true
		}
		// Redis error: fall through to DB advisory lock.
	}
	release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
	if !ok {
		s.maybeLogSkip(logPrefix)
		return nil, false
	}
	return release, true
}
func
(
s
*
OpsAggregationService
)
maybeLogSkip
(
prefix
string
)
{
s
.
skipLogMu
.
Lock
()
defer
s
.
skipLogMu
.
Unlock
()
now
:=
time
.
Now
()
if
!
s
.
skipLogAt
.
IsZero
()
&&
now
.
Sub
(
s
.
skipLogAt
)
<
time
.
Minute
{
return
}
s
.
skipLogAt
=
now
if
prefix
==
""
{
prefix
=
"[OpsAggregation]"
}
log
.
Printf
(
"%s leader lock held by another instance; skipping"
,
prefix
)
}
func
utcFloorToHour
(
t
time
.
Time
)
time
.
Time
{
return
t
.
UTC
()
.
Truncate
(
time
.
Hour
)
}
func
utcFloorToDay
(
t
time
.
Time
)
time
.
Time
{
u
:=
t
.
UTC
()
y
,
m
,
d
:=
u
.
Date
()
return
time
.
Date
(
y
,
m
,
d
,
0
,
0
,
0
,
0
,
time
.
UTC
)
}
func
minTime
(
a
,
b
time
.
Time
)
time
.
Time
{
if
a
.
Before
(
b
)
{
return
a
}
return
b
}
backend/internal/service/ops_alert_evaluator_service.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"fmt"
"log"
"math"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// Tunables for the alert evaluator background job.
const (
	// Heartbeat job name recorded via UpsertJobHeartbeat.
	opsAlertEvaluatorJobName = "ops_alert_evaluator"
	// Hard cap for a single evaluation pass.
	opsAlertEvaluatorTimeout = 45 * time.Second
	// Default Redis leader-lock key/TTL (overridable via runtime settings).
	opsAlertEvaluatorLeaderLockKey = "ops:alert:evaluator:leader"
	opsAlertEvaluatorLeaderLockTTL = 90 * time.Second
	// Rate limit for "lock held elsewhere" log lines.
	opsAlertEvaluatorSkipLogInterval = 1 * time.Minute
)
// opsAlertEvaluatorReleaseScript atomically deletes the leader-lock key only
// when this instance still owns it (compare-and-delete), so a lock that
// expired and was re-acquired elsewhere is never released by mistake.
var opsAlertEvaluatorReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("DEL", KEYS[1])
end
return 0
`)
// OpsAlertEvaluatorService periodically evaluates the configured alert rules,
// creates/resolves alert events, and sends notification emails. Multi-replica
// deployments coordinate via an optional Redis leader lock.
type OpsAlertEvaluatorService struct {
	opsService   *OpsService    // metric/runtime-settings/silencing lookups
	opsRepo      OpsRepository  // rules, events, and heartbeat persistence
	emailService *EmailService  // outbound notification channel (may be nil)
	redisClient  *redis.Client  // leader-lock backend (may be nil)
	cfg          *config.Config // static Ops.Enabled gate

	instanceID string         // unique per-process token stored as the lock value
	stopCh     chan struct{}  // closed by Stop to terminate the run loop
	startOnce  sync.Once      // guards loop startup
	stopOnce   sync.Once      // guards stopCh close
	wg         sync.WaitGroup // lets Stop wait for the run loop to exit

	mu         sync.Mutex                   // guards ruleStates
	ruleStates map[int64]*opsAlertRuleState // per-rule sustained-breach tracking

	emailLimiter *slidingWindowLimiter // caps alert emails per hour

	skipLogMu sync.Mutex // guards skipLogAt
	skipLogAt time.Time  // last "lock held elsewhere" log time (rate limit)

	warnNoRedisOnce sync.Once // one-time warning about missing/failing Redis
}
// opsAlertRuleState tracks the in-memory breach streak for a single rule.
type opsAlertRuleState struct {
	LastEvaluatedAt     time.Time // when the rule was last evaluated (staleness reset)
	ConsecutiveBreaches int       // breaches in a row; reset on any non-breach or gap
}
func
NewOpsAlertEvaluatorService
(
opsService
*
OpsService
,
opsRepo
OpsRepository
,
emailService
*
EmailService
,
redisClient
*
redis
.
Client
,
cfg
*
config
.
Config
,
)
*
OpsAlertEvaluatorService
{
return
&
OpsAlertEvaluatorService
{
opsService
:
opsService
,
opsRepo
:
opsRepo
,
emailService
:
emailService
,
redisClient
:
redisClient
,
cfg
:
cfg
,
instanceID
:
uuid
.
NewString
(),
ruleStates
:
map
[
int64
]
*
opsAlertRuleState
{},
emailLimiter
:
newSlidingWindowLimiter
(
0
,
time
.
Hour
),
}
}
func
(
s
*
OpsAlertEvaluatorService
)
Start
()
{
if
s
==
nil
{
return
}
s
.
startOnce
.
Do
(
func
()
{
if
s
.
stopCh
==
nil
{
s
.
stopCh
=
make
(
chan
struct
{})
}
go
s
.
run
()
})
}
func
(
s
*
OpsAlertEvaluatorService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
if
s
.
stopCh
!=
nil
{
close
(
s
.
stopCh
)
}
})
s
.
wg
.
Wait
()
}
func
(
s
*
OpsAlertEvaluatorService
)
run
()
{
s
.
wg
.
Add
(
1
)
defer
s
.
wg
.
Done
()
// Start immediately to produce early feedback in ops dashboard.
timer
:=
time
.
NewTimer
(
0
)
defer
timer
.
Stop
()
for
{
select
{
case
<-
timer
.
C
:
interval
:=
s
.
getInterval
()
s
.
evaluateOnce
(
interval
)
timer
.
Reset
(
interval
)
case
<-
s
.
stopCh
:
return
}
}
}
func
(
s
*
OpsAlertEvaluatorService
)
getInterval
()
time
.
Duration
{
// Default.
interval
:=
60
*
time
.
Second
if
s
==
nil
||
s
.
opsService
==
nil
{
return
interval
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
cfg
,
err
:=
s
.
opsService
.
GetOpsAlertRuntimeSettings
(
ctx
)
if
err
!=
nil
||
cfg
==
nil
{
return
interval
}
if
cfg
.
EvaluationIntervalSeconds
<=
0
{
return
interval
}
if
cfg
.
EvaluationIntervalSeconds
<
1
{
return
interval
}
if
cfg
.
EvaluationIntervalSeconds
>
int
((
24
*
time
.
Hour
)
.
Seconds
())
{
return
interval
}
return
time
.
Duration
(
cfg
.
EvaluationIntervalSeconds
)
*
time
.
Second
}
// evaluateOnce performs a single evaluation pass over all alert rules:
// computes each enabled rule's metric for its window, tracks sustained
// breaches, fires new events (respecting silencing and cooldown), resolves
// events whose condition cleared, sends notification emails, and records a
// heartbeat summarizing the pass.
func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
	if s == nil || s.opsRepo == nil {
		return
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAlertEvaluatorTimeout)
	defer cancel()
	if s.opsService != nil && !s.opsService.IsMonitoringEnabled(ctx) {
		return
	}
	// Load runtime settings, falling back to defaults on any failure.
	runtimeCfg := defaultOpsAlertRuntimeSettings()
	if s.opsService != nil {
		if loaded, err := s.opsService.GetOpsAlertRuntimeSettings(ctx); err == nil && loaded != nil {
			runtimeCfg = loaded
		}
	}
	release, ok := s.tryAcquireLeaderLock(ctx, runtimeCfg.DistributedLock)
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	startedAt := time.Now().UTC()
	runAt := startedAt
	rules, err := s.opsRepo.ListAlertRules(ctx)
	if err != nil {
		s.recordHeartbeatError(runAt, time.Since(startedAt), err)
		log.Printf("[OpsAlertEvaluator] list rules failed: %v", err)
		return
	}
	// Counters reported in the heartbeat result string.
	rulesTotal := len(rules)
	rulesEnabled := 0
	rulesEvaluated := 0
	eventsCreated := 0
	eventsResolved := 0
	emailsSent := 0
	now := time.Now().UTC()
	// Evaluate against whole minutes to keep windows stable between passes.
	safeEnd := now.Truncate(time.Minute)
	if safeEnd.IsZero() {
		safeEnd = now
	}
	// Best effort: system metrics missing just disables system-metric rules.
	systemMetrics, _ := s.opsRepo.GetLatestSystemMetrics(ctx, 1)
	// Cleanup stale state for removed rules.
	s.pruneRuleStates(rules)
	for _, rule := range rules {
		if rule == nil || !rule.Enabled || rule.ID <= 0 {
			continue
		}
		rulesEnabled++
		scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
		windowMinutes := rule.WindowMinutes
		if windowMinutes <= 0 {
			windowMinutes = 1
		}
		windowStart := safeEnd.Add(-time.Duration(windowMinutes) * time.Minute)
		windowEnd := safeEnd
		metricValue, ok := s.computeRuleMetric(ctx, rule, systemMetrics, windowStart, windowEnd, scopePlatform, scopeGroupID)
		if !ok {
			// No data for this rule: clear the breach streak and move on.
			s.resetRuleState(rule.ID, now)
			continue
		}
		rulesEvaluated++
		breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
		required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
		consecutive := s.updateRuleBreaches(rule.ID, now, interval, breachedNow)
		activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID)
		if err != nil {
			log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
			continue
		}
		if breachedNow && consecutive >= required {
			if activeEvent != nil {
				// Already firing: nothing more to do for this rule.
				continue
			}
			// Scoped silencing: if a matching silence exists, skip creating a firing event.
			if s.opsService != nil {
				platform := strings.TrimSpace(scopePlatform)
				region := scopeRegion
				if platform != "" {
					if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
						continue
					}
				}
			}
			// Cooldown: suppress re-firing too soon after the most recent event.
			latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
				continue
			}
			if latestEvent != nil && rule.CooldownMinutes > 0 {
				cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
				if now.Sub(latestEvent.FiredAt) < cooldown {
					continue
				}
			}
			firedEvent := &OpsAlertEvent{
				RuleID:         rule.ID,
				Severity:       strings.TrimSpace(rule.Severity),
				Status:         OpsAlertStatusFiring,
				Title:          fmt.Sprintf("%s: %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name)),
				Description:    buildOpsAlertDescription(rule, metricValue, windowMinutes, scopePlatform, scopeGroupID),
				MetricValue:    float64Ptr(metricValue),
				ThresholdValue: float64Ptr(rule.Threshold),
				Dimensions:     buildOpsAlertDimensions(scopePlatform, scopeGroupID),
				FiredAt:        now,
				CreatedAt:      now,
			}
			created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
				continue
			}
			eventsCreated++
			if created != nil && created.ID > 0 {
				if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
					emailsSent++
				}
			}
			continue
		}
		// Not breached: resolve active event if present.
		if activeEvent != nil {
			resolvedAt := now
			if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
				log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
			} else {
				eventsResolved++
			}
		}
	}
	result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
	s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
}
func
(
s
*
OpsAlertEvaluatorService
)
pruneRuleStates
(
rules
[]
*
OpsAlertRule
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
live
:=
map
[
int64
]
struct
{}{}
for
_
,
r
:=
range
rules
{
if
r
!=
nil
&&
r
.
ID
>
0
{
live
[
r
.
ID
]
=
struct
{}{}
}
}
for
id
:=
range
s
.
ruleStates
{
if
_
,
ok
:=
live
[
id
];
!
ok
{
delete
(
s
.
ruleStates
,
id
)
}
}
}
func
(
s
*
OpsAlertEvaluatorService
)
resetRuleState
(
ruleID
int64
,
now
time
.
Time
)
{
if
ruleID
<=
0
{
return
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
state
,
ok
:=
s
.
ruleStates
[
ruleID
]
if
!
ok
{
state
=
&
opsAlertRuleState
{}
s
.
ruleStates
[
ruleID
]
=
state
}
state
.
LastEvaluatedAt
=
now
state
.
ConsecutiveBreaches
=
0
}
func
(
s
*
OpsAlertEvaluatorService
)
updateRuleBreaches
(
ruleID
int64
,
now
time
.
Time
,
interval
time
.
Duration
,
breached
bool
)
int
{
if
ruleID
<=
0
{
return
0
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
state
,
ok
:=
s
.
ruleStates
[
ruleID
]
if
!
ok
{
state
=
&
opsAlertRuleState
{}
s
.
ruleStates
[
ruleID
]
=
state
}
if
!
state
.
LastEvaluatedAt
.
IsZero
()
&&
interval
>
0
{
if
now
.
Sub
(
state
.
LastEvaluatedAt
)
>
interval
*
2
{
state
.
ConsecutiveBreaches
=
0
}
}
state
.
LastEvaluatedAt
=
now
if
breached
{
state
.
ConsecutiveBreaches
++
}
else
{
state
.
ConsecutiveBreaches
=
0
}
return
state
.
ConsecutiveBreaches
}
func
requiredSustainedBreaches
(
sustainedMinutes
int
,
interval
time
.
Duration
)
int
{
if
sustainedMinutes
<=
0
{
return
1
}
if
interval
<=
0
{
return
sustainedMinutes
}
required
:=
int
(
math
.
Ceil
(
float64
(
sustainedMinutes
*
60
)
/
interval
.
Seconds
()))
if
required
<
1
{
return
1
}
return
required
}
// parseOpsAlertRuleScope extracts the optional scope dimensions from a
// rule's filters map: "platform" (trimmed string), "group_id" (accepted as
// float64/int64/int/numeric string, positive only), and "region" (trimmed,
// non-empty). Missing or malformed entries yield zero values / nil.
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
	if filters == nil {
		return "", nil, nil
	}
	if raw, ok := filters["platform"]; ok {
		if str, isStr := raw.(string); isStr {
			platform = strings.TrimSpace(str)
		}
	}
	if raw, ok := filters["group_id"]; ok {
		// JSON decoding usually produces float64; accept int flavors and
		// numeric strings as well.
		switch v := raw.(type) {
		case float64:
			if v > 0 {
				id := int64(v)
				groupID = &id
			}
		case int64:
			if v > 0 {
				id := v
				groupID = &id
			}
		case int:
			if v > 0 {
				id := int64(v)
				groupID = &id
			}
		case string:
			if n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil && n > 0 {
				groupID = &n
			}
		}
	}
	if raw, ok := filters["region"]; ok {
		if str, isStr := raw.(string); isStr {
			if trimmed := strings.TrimSpace(str); trimmed != "" {
				region = &trimmed
			}
		}
	}
	return platform, groupID, region
}
// computeRuleMetric resolves the current value for a rule's metric type.
// Returns (value, true) when the metric could be computed, (0, false) when
// data is unavailable (which the caller treats as "skip, reset streak").
// System-level metrics come from the latest system snapshot; account/group
// metrics from the availability service; everything else from the dashboard
// overview over [start, end).
func (s *OpsAlertEvaluatorService) computeRuleMetric(
	ctx context.Context,
	rule *OpsAlertRule,
	systemMetrics *OpsSystemMetricsSnapshot,
	start time.Time,
	end time.Time,
	platform string,
	groupID *int64,
) (float64, bool) {
	if rule == nil {
		return 0, false
	}
	switch strings.TrimSpace(rule.MetricType) {
	case "cpu_usage_percent":
		if systemMetrics != nil && systemMetrics.CPUUsagePercent != nil {
			return *systemMetrics.CPUUsagePercent, true
		}
		return 0, false
	case "memory_usage_percent":
		if systemMetrics != nil && systemMetrics.MemoryUsagePercent != nil {
			return *systemMetrics.MemoryUsagePercent, true
		}
		return 0, false
	case "concurrency_queue_depth":
		if systemMetrics != nil && systemMetrics.ConcurrencyQueueDepth != nil {
			return float64(*systemMetrics.ConcurrencyQueueDepth), true
		}
		return 0, false
	case "group_available_accounts":
		// Requires an explicit group scope.
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		if availability.Group == nil {
			// Group stats absent: treat as zero available accounts.
			return 0, true
		}
		return float64(availability.Group.AvailableCount), true
	case "group_available_ratio":
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		return computeGroupAvailableRatio(availability.Group), true
	case "account_rate_limited_count":
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})), true
	case "account_error_count":
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		// Errored accounts that are not in a temporary unschedulable window.
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.HasError && acc.TempUnschedulableUntil == nil
		})), true
	}
	// Remaining metric types are ratio metrics derived from the raw overview.
	overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
		StartTime: start,
		EndTime:   end,
		Platform:  platform,
		GroupID:   groupID,
		QueryMode: OpsQueryModeRaw,
	})
	if err != nil {
		return 0, false
	}
	if overview == nil {
		return 0, false
	}
	switch strings.TrimSpace(rule.MetricType) {
	case "success_rate":
		// No SLA-eligible requests in the window → no signal.
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.SLA * 100, true
	case "error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.ErrorRate * 100, true
	case "upstream_error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.UpstreamErrorRate * 100, true
	default:
		return 0, false
	}
}
// compareMetric evaluates `value <operator> threshold` for the rule's
// comparison operator (whitespace-tolerant). Unknown operators never match.
// Note: "==" / "!=" are exact float comparisons by design.
func compareMetric(value float64, operator string, threshold float64) bool {
	op := strings.TrimSpace(operator)
	switch op {
	case ">":
		return value > threshold
	case ">=":
		return value >= threshold
	case "<":
		return value < threshold
	case "<=":
		return value <= threshold
	case "==":
		return value == threshold
	case "!=":
		return value != threshold
	}
	return false
}
// buildOpsAlertDimensions assembles the event dimensions map from the rule
// scope. Returns nil (not an empty map) when there is no scope, so callers
// can store NULL rather than "{}".
func buildOpsAlertDimensions(platform string, groupID *int64) map[string]any {
	dims := make(map[string]any, 2)
	if p := strings.TrimSpace(platform); p != "" {
		dims["platform"] = p
	}
	if groupID != nil && *groupID > 0 {
		dims["group_id"] = *groupID
	}
	if len(dims) == 0 {
		return nil
	}
	return dims
}
func
buildOpsAlertDescription
(
rule
*
OpsAlertRule
,
value
float64
,
windowMinutes
int
,
platform
string
,
groupID
*
int64
)
string
{
if
rule
==
nil
{
return
""
}
scope
:=
"overall"
if
strings
.
TrimSpace
(
platform
)
!=
""
{
scope
=
fmt
.
Sprintf
(
"platform=%s"
,
strings
.
TrimSpace
(
platform
))
}
if
groupID
!=
nil
&&
*
groupID
>
0
{
scope
=
fmt
.
Sprintf
(
"%s group_id=%d"
,
scope
,
*
groupID
)
}
if
windowMinutes
<=
0
{
windowMinutes
=
1
}
return
fmt
.
Sprintf
(
"%s %s %.2f (current %.2f) over last %dm (%s)"
,
strings
.
TrimSpace
(
rule
.
MetricType
),
strings
.
TrimSpace
(
rule
.
Operator
),
rule
.
Threshold
,
value
,
windowMinutes
,
strings
.
TrimSpace
(
scope
),
)
}
// maybeSendAlertEmail sends the alert email for a freshly fired event when
// every gate allows it: not already sent, rule opted in, email notifications
// enabled with recipients, severity at/above the configured minimum, not
// silenced, and within the hourly rate limit. Returns true when at least one
// recipient was sent; the event is then marked email_sent (best effort).
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
	if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
		return false
	}
	if event.EmailSent {
		return false
	}
	if !rule.NotifyEmail {
		return false
	}
	emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
	if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
		return false
	}
	if len(emailCfg.Alert.Recipients) == 0 {
		return false
	}
	if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
		return false
	}
	if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
		if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
			return false
		}
	}
	// Apply/update rate limiter.
	s.emailLimiter.SetLimit(emailCfg.Alert.RateLimitPerHour)
	subject := fmt.Sprintf("[Ops Alert][%s] %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name))
	body := buildOpsAlertEmailBody(rule, event)
	anySent := false
	for _, to := range emailCfg.Alert.Recipients {
		addr := strings.TrimSpace(to)
		if addr == "" {
			continue
		}
		// The limiter is consulted per recipient, so one event may reach only
		// a subset of recipients when the hourly budget runs out.
		if !s.emailLimiter.Allow(time.Now().UTC()) {
			continue
		}
		if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
			// Ignore per-recipient failures; continue best-effort.
			continue
		}
		anySent = true
	}
	if anySent {
		// Background ctx: the marking must survive the evaluation deadline.
		_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
	}
	return anySent
}
// buildOpsAlertEmailBody renders the HTML email body for an alert event.
// All dynamic strings pass through htmlEscape; the metric value/threshold
// prefer the values frozen on the event, falling back to the rule threshold.
// Returns "" when either argument is nil.
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
	if rule == nil || event == nil {
		return ""
	}
	metric := strings.TrimSpace(rule.MetricType)
	value := "-"
	threshold := fmt.Sprintf("%.2f", rule.Threshold)
	if event.MetricValue != nil {
		value = fmt.Sprintf("%.2f", *event.MetricValue)
	}
	if event.ThresholdValue != nil {
		threshold = fmt.Sprintf("%.2f", *event.ThresholdValue)
	}
	return fmt.Sprintf(`
<h2>Ops Alert</h2>
<p><b>Rule</b>: %s</p>
<p><b>Severity</b>: %s</p>
<p><b>Status</b>: %s</p>
<p><b>Metric</b>: %s %s %s</p>
<p><b>Fired at</b>: %s</p>
<p><b>Description</b>: %s</p>
`,
		htmlEscape(rule.Name),
		htmlEscape(rule.Severity),
		htmlEscape(event.Status),
		htmlEscape(metric),
		htmlEscape(rule.Operator),
		htmlEscape(fmt.Sprintf("%s (threshold %s)", value, threshold)),
		event.FiredAt.Format(time.RFC3339),
		htmlEscape(event.Description),
	)
}
func
shouldSendOpsAlertEmailByMinSeverity
(
minSeverity
string
,
ruleSeverity
string
)
bool
{
minSeverity
=
strings
.
ToLower
(
strings
.
TrimSpace
(
minSeverity
))
if
minSeverity
==
""
{
return
true
}
eventLevel
:=
opsEmailSeverityForOps
(
ruleSeverity
)
minLevel
:=
strings
.
ToLower
(
minSeverity
)
rank
:=
func
(
level
string
)
int
{
switch
level
{
case
"critical"
:
return
3
case
"warning"
:
return
2
case
"info"
:
return
1
default
:
return
0
}
}
return
rank
(
eventLevel
)
>=
rank
(
minLevel
)
}
// opsEmailSeverityForOps maps an ops rule severity (P0/P1/...) onto the
// email notification severity scale; anything unrecognized is "info".
func opsEmailSeverityForOps(severity string) string {
	normalized := strings.ToUpper(strings.TrimSpace(severity))
	switch normalized {
	case "P0":
		return "critical"
	case "P1":
		return "warning"
	}
	return "info"
}
// isOpsAlertSilenced reports whether an alert email should be suppressed by
// the silencing settings: either a global until-timestamp in the future, or
// a matching silence entry (time window + optional rule ID + optional
// severity list). Unparseable timestamps are skipped, never treated as
// active silences.
//
// NOTE(review): the Severities branch dereferences event.Severity and
// rule.Severity without nil checks; the only caller passes non-nil values —
// confirm before reusing this helper elsewhere.
func isOpsAlertSilenced(now time.Time, rule *OpsAlertRule, event *OpsAlertEvent, silencing OpsAlertSilencingSettings) bool {
	if !silencing.Enabled {
		return false
	}
	if now.IsZero() {
		now = time.Now().UTC()
	}
	// Global silence: everything is muted until the configured instant.
	if strings.TrimSpace(silencing.GlobalUntilRFC3339) != "" {
		if t, err := time.Parse(time.RFC3339, strings.TrimSpace(silencing.GlobalUntilRFC3339)); err == nil {
			if now.Before(t) {
				return true
			}
		}
	}
	for _, entry := range silencing.Entries {
		untilRaw := strings.TrimSpace(entry.UntilRFC3339)
		if untilRaw == "" {
			continue
		}
		until, err := time.Parse(time.RFC3339, untilRaw)
		if err != nil {
			continue
		}
		if now.After(until) {
			// Entry already expired.
			continue
		}
		// Rule filter: an entry bound to a different rule does not apply.
		if entry.RuleID != nil && rule != nil && rule.ID > 0 && *entry.RuleID != rule.ID {
			continue
		}
		// Severity filter: the entry applies when any listed severity matches
		// either the event's or the rule's severity (case-insensitive).
		if len(entry.Severities) > 0 {
			match := false
			for _, s := range entry.Severities {
				if strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(event.Severity)) ||
					strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(rule.Severity)) {
					match = true
					break
				}
			}
			if !match {
				continue
			}
		}
		return true
	}
	return false
}
// tryAcquireLeaderLock attempts to become the single active evaluator instance
// using a Redis SetNX-based leader lock. It returns a release function (nil
// when there is nothing to release) and whether this instance may proceed with
// the current evaluation cycle.
//
// When locking is disabled or no Redis client is configured, the caller always
// proceeds without a lock. When Redis errors, it fails closed (skips the cycle).
func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, lock OpsDistributedLockSettings) (func(), bool) {
	if !lock.Enabled {
		return nil, true
	}
	if s.redisClient == nil {
		// NOTE(review): warnNoRedisOnce is shared with the SetNX-failure path
		// below, so only the first of the two warnings is ever logged per
		// process lifetime — confirm this is intentional.
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock")
		})
		return nil, true
	}
	// Fall back to package defaults for an empty key or non-positive TTL.
	key := strings.TrimSpace(lock.Key)
	if key == "" {
		key = opsAlertEvaluatorLeaderLockKey
	}
	ttl := time.Duration(lock.TTLSeconds) * time.Second
	if ttl <= 0 {
		ttl = opsAlertEvaluatorLeaderLockTTL
	}
	ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
	if err != nil {
		// Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky.
		// Single-node deployments can disable the distributed lock via runtime settings.
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
		})
		return nil, false
	}
	if !ok {
		// Another instance holds the lock; the log is throttled by maybeLogSkip.
		s.maybeLogSkip(key)
		return nil, false
	}
	// Release via a Lua script keyed on instanceID — presumably it only deletes
	// the lock when the stored value still matches this instance (script is
	// defined elsewhere; verify).
	return func() {
		_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
	}, true
}
func
(
s
*
OpsAlertEvaluatorService
)
maybeLogSkip
(
key
string
)
{
s
.
skipLogMu
.
Lock
()
defer
s
.
skipLogMu
.
Unlock
()
now
:=
time
.
Now
()
if
!
s
.
skipLogAt
.
IsZero
()
&&
now
.
Sub
(
s
.
skipLogAt
)
<
opsAlertEvaluatorSkipLogInterval
{
return
}
s
.
skipLogAt
=
now
log
.
Printf
(
"[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)"
,
key
)
}
func
(
s
*
OpsAlertEvaluatorService
)
recordHeartbeatSuccess
(
runAt
time
.
Time
,
duration
time
.
Duration
,
result
string
)
{
if
s
==
nil
||
s
.
opsRepo
==
nil
{
return
}
now
:=
time
.
Now
()
.
UTC
()
durMs
:=
duration
.
Milliseconds
()
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
msg
:=
strings
.
TrimSpace
(
result
)
if
msg
==
""
{
msg
=
"ok"
}
msg
=
truncateString
(
msg
,
2048
)
_
=
s
.
opsRepo
.
UpsertJobHeartbeat
(
ctx
,
&
OpsUpsertJobHeartbeatInput
{
JobName
:
opsAlertEvaluatorJobName
,
LastRunAt
:
&
runAt
,
LastSuccessAt
:
&
now
,
LastDurationMs
:
&
durMs
,
LastResult
:
&
msg
,
})
}
func
(
s
*
OpsAlertEvaluatorService
)
recordHeartbeatError
(
runAt
time
.
Time
,
duration
time
.
Duration
,
err
error
)
{
if
s
==
nil
||
s
.
opsRepo
==
nil
||
err
==
nil
{
return
}
now
:=
time
.
Now
()
.
UTC
()
durMs
:=
duration
.
Milliseconds
()
msg
:=
truncateString
(
err
.
Error
(),
2048
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
_
=
s
.
opsRepo
.
UpsertJobHeartbeat
(
ctx
,
&
OpsUpsertJobHeartbeatInput
{
JobName
:
opsAlertEvaluatorJobName
,
LastRunAt
:
&
runAt
,
LastErrorAt
:
&
now
,
LastError
:
&
msg
,
LastDurationMs
:
&
durMs
,
})
}
// htmlEscapeReplacer escapes the five HTML-special characters. Built once at
// package init instead of on every htmlEscape call (strings.NewReplacer
// allocates and builds a lookup structure; the previous implementation paid
// that cost per invocation).
var htmlEscapeReplacer = strings.NewReplacer(
	"&", "&amp;",
	"<", "&lt;",
	">", "&gt;",
	`"`, "&quot;",
	"'", "&#39;",
)

// htmlEscape returns s with HTML special characters (& < > " ') replaced by
// their entity references, making it safe to embed in HTML text or attribute
// values.
func htmlEscape(s string) string {
	return htmlEscapeReplacer.Replace(s)
}
// slidingWindowLimiter caps how many sends may happen within a rolling time
// window. Safe for concurrent use; all fields are guarded by mu.
type slidingWindowLimiter struct {
	mu sync.Mutex
	// limit is the maximum number of sends per window; <= 0 disables limiting.
	limit int
	// window is the rolling interval over which sends are counted.
	window time.Duration
	// sent holds the timestamps of sends still inside the window; entries are
	// pruned lazily on each Allow call.
	sent []time.Time
}
func
newSlidingWindowLimiter
(
limit
int
,
window
time
.
Duration
)
*
slidingWindowLimiter
{
if
window
<=
0
{
window
=
time
.
Hour
}
return
&
slidingWindowLimiter
{
limit
:
limit
,
window
:
window
,
sent
:
[]
time
.
Time
{},
}
}
func
(
l
*
slidingWindowLimiter
)
SetLimit
(
limit
int
)
{
l
.
mu
.
Lock
()
defer
l
.
mu
.
Unlock
()
l
.
limit
=
limit
}
func
(
l
*
slidingWindowLimiter
)
Allow
(
now
time
.
Time
)
bool
{
l
.
mu
.
Lock
()
defer
l
.
mu
.
Unlock
()
if
l
.
limit
<=
0
{
return
true
}
cutoff
:=
now
.
Add
(
-
l
.
window
)
keep
:=
l
.
sent
[
:
0
]
for
_
,
t
:=
range
l
.
sent
{
if
t
.
After
(
cutoff
)
{
keep
=
append
(
keep
,
t
)
}
}
l
.
sent
=
keep
if
len
(
l
.
sent
)
>=
l
.
limit
{
return
false
}
l
.
sent
=
append
(
l
.
sent
,
now
)
return
true
}
// computeGroupAvailableRatio returns the available percentage for a group.
// Formula: (AvailableCount / TotalAccounts) * 100.
// Returns 0 when TotalAccounts is 0.
func
computeGroupAvailableRatio
(
group
*
GroupAvailability
)
float64
{
if
group
==
nil
||
group
.
TotalAccounts
<=
0
{
return
0
}
return
(
float64
(
group
.
AvailableCount
)
/
float64
(
group
.
TotalAccounts
))
*
100
}
// countAccountsByCondition counts accounts that satisfy the given condition.
func
countAccountsByCondition
(
accounts
map
[
int64
]
*
AccountAvailability
,
condition
func
(
*
AccountAvailability
)
bool
)
int64
{
if
len
(
accounts
)
==
0
||
condition
==
nil
{
return
0
}
var
count
int64
for
_
,
account
:=
range
accounts
{
if
account
!=
nil
&&
condition
(
account
)
{
count
++
}
}
return
count
}
backend/internal/service/ops_alert_evaluator_service_test.go
0 → 100644
View file @
b9b4db3d
//go:build unit
package
service
import
(
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// stubOpsRepo is a test double for OpsRepository. Only GetDashboardOverview is
// implemented here; other calls fall through to the embedded nil interface and
// would panic if invoked.
type stubOpsRepo struct {
	OpsRepository
	// overview is returned by GetDashboardOverview when non-nil and err is nil.
	overview *OpsDashboardOverview
	// err, when non-nil, is returned by GetDashboardOverview.
	err error
}
func
(
s
*
stubOpsRepo
)
GetDashboardOverview
(
ctx
context
.
Context
,
filter
*
OpsDashboardFilter
)
(
*
OpsDashboardOverview
,
error
)
{
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
if
s
.
overview
!=
nil
{
return
s
.
overview
,
nil
}
return
&
OpsDashboardOverview
{},
nil
}
func
TestComputeGroupAvailableRatio
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Run
(
"正常情况: 10个账号, 8个可用 = 80%"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
computeGroupAvailableRatio
(
&
GroupAvailability
{
TotalAccounts
:
10
,
AvailableCount
:
8
,
})
require
.
InDelta
(
t
,
80.0
,
got
,
0.0001
)
})
t
.
Run
(
"边界情况: TotalAccounts = 0 应返回 0"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
computeGroupAvailableRatio
(
&
GroupAvailability
{
TotalAccounts
:
0
,
AvailableCount
:
8
,
})
require
.
Equal
(
t
,
0.0
,
got
)
})
t
.
Run
(
"边界情况: AvailableCount = 0 应返回 0%"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
computeGroupAvailableRatio
(
&
GroupAvailability
{
TotalAccounts
:
10
,
AvailableCount
:
0
,
})
require
.
Equal
(
t
,
0.0
,
got
)
})
}
func
TestCountAccountsByCondition
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Run
(
"测试限流账号统计: acc.IsRateLimited"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
accounts
:=
map
[
int64
]
*
AccountAvailability
{
1
:
{
IsRateLimited
:
true
},
2
:
{
IsRateLimited
:
false
},
3
:
{
IsRateLimited
:
true
},
}
got
:=
countAccountsByCondition
(
accounts
,
func
(
acc
*
AccountAvailability
)
bool
{
return
acc
.
IsRateLimited
})
require
.
Equal
(
t
,
int64
(
2
),
got
)
})
t
.
Run
(
"测试错误账号统计(排除临时不可调度): acc.HasError && acc.TempUnschedulableUntil == nil"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
until
:=
time
.
Now
()
.
UTC
()
.
Add
(
5
*
time
.
Minute
)
accounts
:=
map
[
int64
]
*
AccountAvailability
{
1
:
{
HasError
:
true
},
2
:
{
HasError
:
true
,
TempUnschedulableUntil
:
&
until
},
3
:
{
HasError
:
false
},
}
got
:=
countAccountsByCondition
(
accounts
,
func
(
acc
*
AccountAvailability
)
bool
{
return
acc
.
HasError
&&
acc
.
TempUnschedulableUntil
==
nil
})
require
.
Equal
(
t
,
int64
(
1
),
got
)
})
t
.
Run
(
"边界情况: 空 map 应返回 0"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
countAccountsByCondition
(
map
[
int64
]
*
AccountAvailability
{},
func
(
acc
*
AccountAvailability
)
bool
{
return
acc
.
IsRateLimited
})
require
.
Equal
(
t
,
int64
(
0
),
got
)
})
}
// TestComputeRuleMetricNewIndicators exercises computeRuleMetric for the
// group-availability and account-counter metric types against a fixed
// availability snapshot, including the cases where a required group_id is
// missing (expected ok=false).
func TestComputeRuleMetricNewIndicators(t *testing.T) {
	t.Parallel()
	groupID := int64(101)
	platform := "openai"
	// Fixed snapshot: 8/10 available in the group; accounts 1-2 rate limited,
	// 3-4 errored (4 is temporarily unschedulable and should be excluded by
	// the error counter), 5 healthy.
	availability := &OpsAccountAvailability{
		Group: &GroupAvailability{
			GroupID:        groupID,
			TotalAccounts:  10,
			AvailableCount: 8,
		},
		Accounts: map[int64]*AccountAvailability{
			1: {IsRateLimited: true},
			2: {IsRateLimited: true},
			3: {HasError: true},
			4: {HasError: true, TempUnschedulableUntil: timePtr(time.Now().UTC().Add(2 * time.Minute))},
			5: {HasError: false, IsRateLimited: false},
		},
	}
	// Inject the snapshot through the service's availability hook so no real
	// lookup happens.
	opsService := &OpsService{
		getAccountAvailability: func(_ context.Context, _ string, _ *int64) (*OpsAccountAvailability, error) {
			return availability, nil
		},
	}
	svc := &OpsAlertEvaluatorService{
		opsService: opsService,
		opsRepo:    &stubOpsRepo{overview: &OpsDashboardOverview{}},
	}
	start := time.Now().UTC().Add(-5 * time.Minute)
	end := time.Now().UTC()
	ctx := context.Background()
	tests := []struct {
		name       string
		metricType string
		groupID    *int64
		wantValue  float64
		wantOK     bool
	}{
		{
			name:       "group_available_accounts",
			metricType: "group_available_accounts",
			groupID:    &groupID,
			wantValue:  8,
			wantOK:     true,
		},
		{
			name:       "group_available_ratio",
			metricType: "group_available_ratio",
			groupID:    &groupID,
			wantValue:  80.0,
			wantOK:     true,
		},
		{
			name:       "account_rate_limited_count",
			metricType: "account_rate_limited_count",
			groupID:    nil,
			wantValue:  2,
			wantOK:     true,
		},
		{
			name:       "account_error_count",
			metricType: "account_error_count",
			groupID:    nil,
			wantValue:  1,
			wantOK:     true,
		},
		{
			// Group-scoped metrics require a group_id; expect ok=false.
			name:       "group_available_accounts without group_id returns false",
			metricType: "group_available_accounts",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
		{
			name:       "group_available_ratio without group_id returns false",
			metricType: "group_available_ratio",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			rule := &OpsAlertRule{
				MetricType: tt.metricType,
			}
			gotValue, gotOK := svc.computeRuleMetric(ctx, rule, nil, start, end, platform, tt.groupID)
			require.Equal(t, tt.wantOK, gotOK)
			if !tt.wantOK {
				// Value is unspecified when the metric cannot be computed.
				return
			}
			require.InDelta(t, tt.wantValue, gotValue, 0.0001)
		})
	}
}
backend/internal/service/ops_alert_models.go
0 → 100644
View file @
b9b4db3d
package
service
import
"time"
// Ops alert rule/event models.
//
// NOTE: These are admin-facing DTOs and intentionally keep JSON naming aligned
// with the existing ops dashboard frontend (backup style).
// Status values for OpsAlertEvent.Status.
const (
	OpsAlertStatusFiring         = "firing"
	OpsAlertStatusResolved       = "resolved"
	OpsAlertStatusManualResolved = "manual_resolved"
)
// OpsAlertRule is an admin-facing alert rule DTO: a metric compared against a
// threshold with window/sustain/cooldown timing and optional filters. JSON
// naming is kept aligned with the ops dashboard frontend.
type OpsAlertRule struct {
	ID          int64  `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description"`
	// Enabled gates evaluation of the rule.
	Enabled  bool   `json:"enabled"`
	Severity string `json:"severity"`
	// MetricType selects the evaluated metric (e.g. "group_available_ratio").
	MetricType string `json:"metric_type"`
	// Operator and Threshold form the comparison applied to the metric value.
	Operator  string  `json:"operator"`
	Threshold float64 `json:"threshold"`
	// Timing knobs, all in minutes: evaluation window, how long the condition
	// must hold, and the cooldown between notifications.
	WindowMinutes    int `json:"window_minutes"`
	SustainedMinutes int `json:"sustained_minutes"`
	CooldownMinutes  int `json:"cooldown_minutes"`
	// NotifyEmail enables email notification for this rule.
	NotifyEmail bool `json:"notify_email"`
	// Filters narrows the rule's scope; schema is free-form key/value pairs.
	Filters         map[string]any `json:"filters,omitempty"`
	LastTriggeredAt *time.Time     `json:"last_triggered_at,omitempty"`
	CreatedAt       time.Time      `json:"created_at"`
	UpdatedAt       time.Time      `json:"updated_at"`
}
// OpsAlertEvent is a single alert occurrence produced by evaluating an
// OpsAlertRule.
type OpsAlertEvent struct {
	ID int64 `json:"id"`
	// RuleID references the OpsAlertRule that produced this event.
	RuleID   int64  `json:"rule_id"`
	Severity string `json:"severity"`
	// Status is one of the OpsAlertStatus* constants.
	Status      string `json:"status"`
	Title       string `json:"title"`
	Description string `json:"description"`
	// Observed metric value and the threshold it was compared against at fire
	// time; nil when not applicable.
	MetricValue    *float64 `json:"metric_value,omitempty"`
	ThresholdValue *float64 `json:"threshold_value,omitempty"`
	// Dimensions carries free-form context (e.g. platform/group) for the event.
	Dimensions map[string]any `json:"dimensions,omitempty"`
	FiredAt    time.Time      `json:"fired_at"`
	// ResolvedAt is nil while the event is still firing.
	ResolvedAt *time.Time `json:"resolved_at,omitempty"`
	// EmailSent records whether a notification email has gone out for this event.
	EmailSent bool      `json:"email_sent"`
	CreatedAt time.Time `json:"created_at"`
}
// OpsAlertSilence suppresses notifications for a rule within an optional
// platform/group/region scope until a given time.
type OpsAlertSilence struct {
	ID int64 `json:"id"`
	// RuleID references the silenced OpsAlertRule.
	RuleID   int64  `json:"rule_id"`
	Platform string `json:"platform"`
	// GroupID and Region optionally narrow the silence scope.
	GroupID *int64  `json:"group_id,omitempty"`
	Region  *string `json:"region,omitempty"`
	// Until is the instant at which the silence expires.
	Until  time.Time `json:"until"`
	Reason string    `json:"reason"`
	// CreatedBy is the admin user who created the silence, when known.
	CreatedBy *int64    `json:"created_by,omitempty"`
	CreatedAt time.Time `json:"created_at"`
}
// OpsAlertEventFilter narrows alert-event list queries. Zero values mean
// "no filter" for the corresponding field.
type OpsAlertEventFilter struct {
	// Limit caps the number of returned events.
	Limit int
	// Cursor pagination (descending by fired_at, then id).
	BeforeFiredAt *time.Time
	BeforeID      *int64
	// Optional filters.
	Status    string
	Severity  string
	EmailSent *bool
	StartTime *time.Time
	EndTime   *time.Time
	// Dimensions filters (best-effort).
	Platform string
	GroupID  *int64
}
Prev
1
…
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment