Commit b9b4db3d authored by song's avatar song
Browse files

Merge upstream/main

parents 5a6f60a9 dae0d532
package service
import (
	"strings"
	"time"
)
type Group struct {
ID int64
......@@ -10,6 +13,7 @@ type Group struct {
RateMultiplier float64
IsExclusive bool
Status string
Hydrated bool // indicates the group was loaded from a trusted repository source
SubscriptionType string
DailyLimitUSD *float64
......@@ -26,6 +30,12 @@ type Group struct {
ClaudeCodeOnly bool
FallbackGroupID *int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
// value: 优先账号 ID 列表
ModelRouting map[string][]int64
ModelRoutingEnabled bool
CreatedAt time.Time
UpdatedAt time.Time
......@@ -72,3 +82,58 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return g.ImagePrice2K
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool {
	switch {
	case group == nil:
		return false
	case group.ID <= 0:
		return false
	case !group.Hydrated:
		// Only groups loaded from a trusted repository source are usable.
		return false
	case group.Platform == "" || group.Status == "":
		return false
	default:
		return true
	}
}
// GetRoutingAccountIDs returns the preferred account IDs for the requested
// model, or nil when routing is disabled or no rule matches.
func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
	if requestedModel == "" || !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 {
		return nil
	}
	// Exact matches always win over wildcard rules.
	if ids := g.ModelRouting[requestedModel]; len(ids) > 0 {
		return ids
	}
	// Fall back to wildcard (prefix) patterns.
	for pattern, ids := range g.ModelRouting {
		if len(ids) > 0 && matchModelPattern(pattern, requestedModel) {
			return ids
		}
	}
	return nil
}
// matchModelPattern reports whether model matches pattern. A trailing "*"
// acts as a prefix wildcard (e.g. "claude-opus-*" matches
// "claude-opus-4-20250514"); any other pattern must match exactly.
func matchModelPattern(pattern, model string) bool {
	if pattern == model {
		return true
	}
	// Only a trailing wildcard is supported.
	if !strings.HasSuffix(pattern, "*") {
		return false
	}
	return strings.HasPrefix(model, strings.TrimSuffix(pattern, "*"))
}
......@@ -16,6 +16,7 @@ var (
type GroupRepository interface {
Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*Group, error)
GetByIDLite(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
......@@ -49,13 +50,15 @@ type UpdateGroupRequest struct {
// GroupService manages group CRUD operations and the cache invalidation that
// must follow group mutations.
//
// The flattened merge view contained the `groupRepo` field twice (old and new
// diff lines both kept) — a compile error; the duplicate is removed here.
type GroupService struct {
	groupRepo GroupRepository
	// authCacheInvalidator, when non-nil, evicts API-key auth cache entries
	// after a group is updated or deleted so stale group data is not served.
	authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo GroupRepository) *GroupService {
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
return &GroupService{
groupRepo: groupRepo,
groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
......@@ -154,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
......@@ -166,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
return fmt.Errorf("get group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}
......
package service
import (
"strings"
"time"
)
// modelRateLimitsKey is the account Extra key holding per-model rate-limit state.
const modelRateLimitsKey = "model_rate_limits"

// modelRateLimitScopeClaudeSonnet groups all Claude Sonnet variants under one scope.
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"

// resolveModelRateLimitScope maps a requested model name onto a rate-limit
// scope. The boolean result reports whether the model belongs to any scope.
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(requestedModel))
	// Accept Gemini-style "models/<name>" identifiers as well.
	normalized = strings.TrimPrefix(normalized, "models/")
	if strings.Contains(normalized, "sonnet") {
		return modelRateLimitScopeClaudeSonnet, true
	}
	return "", false
}
// isModelRateLimited reports whether the account currently has an active
// per-model rate limit covering the requested model's scope.
func (a *Account) isModelRateLimited(requestedModel string) bool {
	scope, matched := resolveModelRateLimitScope(requestedModel)
	if !matched {
		return false
	}
	// A missing reset timestamp means no limit is recorded for this scope.
	reset := a.modelRateLimitResetAt(scope)
	return reset != nil && time.Now().Before(*reset)
}
// modelRateLimitResetAt extracts the RFC3339 reset timestamp recorded in the
// account's Extra payload for the given scope. It returns nil when the entry
// is absent, malformed, or unparsable.
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
	if a == nil || a.Extra == nil || scope == "" {
		return nil
	}
	limits, valid := a.Extra[modelRateLimitsKey].(map[string]any)
	if !valid {
		return nil
	}
	entry, valid := limits[scope].(map[string]any)
	if !valid {
		return nil
	}
	raw, valid := entry["rate_limit_reset_at"].(string)
	if !valid || strings.TrimSpace(raw) == "" {
		return nil
	}
	parsed, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return nil
	}
	return &parsed
}
package service
import (
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL = 15 * time.Minute
)
//go:embed prompts/codex_cli_instructions.md
var codexCLIInstructions string
// codexModelMap maps client-facing model aliases — including effort-suffixed
// variants such as "-low"/"-medium"/"-high"/"-xhigh" — to canonical Codex
// upstream model names. Lookups try an exact key first and then a
// case-insensitive scan (see getNormalizedCodexModel); names with no entry
// fall through to the fuzzy matching in normalizeCodexModel.
var codexModelMap = map[string]string{
	"gpt-5.1-codex":             "gpt-5.1-codex",
	"gpt-5.1-codex-low":         "gpt-5.1-codex",
	"gpt-5.1-codex-medium":      "gpt-5.1-codex",
	"gpt-5.1-codex-high":        "gpt-5.1-codex",
	"gpt-5.1-codex-max":         "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-low":     "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-medium":  "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-high":    "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-xhigh":   "gpt-5.1-codex-max",
	"gpt-5.2":                   "gpt-5.2",
	"gpt-5.2-none":              "gpt-5.2",
	"gpt-5.2-low":               "gpt-5.2",
	"gpt-5.2-medium":            "gpt-5.2",
	"gpt-5.2-high":              "gpt-5.2",
	"gpt-5.2-xhigh":             "gpt-5.2",
	"gpt-5.2-codex":             "gpt-5.2-codex",
	"gpt-5.2-codex-low":         "gpt-5.2-codex",
	"gpt-5.2-codex-medium":      "gpt-5.2-codex",
	"gpt-5.2-codex-high":        "gpt-5.2-codex",
	"gpt-5.2-codex-xhigh":       "gpt-5.2-codex",
	"gpt-5.1-codex-mini":        "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-high":   "gpt-5.1-codex-mini",
	"gpt-5.1":                   "gpt-5.1",
	"gpt-5.1-none":              "gpt-5.1",
	"gpt-5.1-low":               "gpt-5.1",
	"gpt-5.1-medium":            "gpt-5.1",
	"gpt-5.1-high":              "gpt-5.1",
	"gpt-5.1-chat-latest":       "gpt-5.1",
	// Legacy gpt-5 aliases are upgraded to their gpt-5.1 equivalents.
	"gpt-5-codex":             "gpt-5.1-codex",
	"codex-mini-latest":       "gpt-5.1-codex-mini",
	"gpt-5-codex-mini":        "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-high":   "gpt-5.1-codex-mini",
	"gpt-5":                   "gpt-5.1",
	"gpt-5-mini":              "gpt-5.1",
	"gpt-5-nano":              "gpt-5.1",
}
// codexTransformResult describes the outcome of applyCodexOAuthTransform.
type codexTransformResult struct {
	Modified        bool   // whether the request body was mutated in place
	NormalizedModel string // canonical model name resolved for the request
	PromptCacheKey  string // trimmed prompt_cache_key from the request, if present
}
// opencodeCacheMetadata is the sidecar JSON stored next to a cached prompt,
// used for TTL freshness checks and ETag revalidation.
type opencodeCacheMetadata struct {
	ETag        string `json:"etag"`                // ETag returned by the last successful fetch
	LastFetch   string `json:"lastFetch,omitempty"` // RFC3339 time of the last successful fetch
	LastChecked int64  `json:"lastChecked"`         // Unix milliseconds of the last freshness check
}
// applyCodexOAuthTransform rewrites an OAuth (ChatGPT internal API) request
// body in place so the Codex upstream accepts it: normalizes the model name,
// forces store=false and stream=true, strips token-limit fields, normalizes
// tool definitions, ensures instructions are present, and filters the input
// items. The result records whether the body changed, the normalized model,
// and any prompt cache key found in the request.
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
	result := codexTransformResult{}
	// Tool-continuation requests influence the storage policy and the input
	// filtering logic below.
	needsToolContinuation := NeedsToolContinuation(reqBody)
	model := ""
	if v, ok := reqBody["model"].(string); ok {
		model = v
	}
	normalizedModel := normalizeCodexModel(model)
	if normalizedModel != "" {
		if model != normalizedModel {
			reqBody["model"] = normalizedModel
			result.Modified = true
		}
		result.NormalizedModel = normalizedModel
	}
	// OAuth via the ChatGPT internal API requires store=false; an explicit
	// true is forcibly overridden to avoid "Store must be set to false".
	if v, ok := reqBody["store"].(bool); !ok || v {
		reqBody["store"] = false
		result.Modified = true
	}
	// Upstream is consumed as a stream; force stream=true when absent/false.
	if v, ok := reqBody["stream"].(bool); !ok || !v {
		reqBody["stream"] = true
		result.Modified = true
	}
	// Token-limit fields are not accepted by this upstream; drop both forms.
	if _, ok := reqBody["max_output_tokens"]; ok {
		delete(reqBody, "max_output_tokens")
		result.Modified = true
	}
	if _, ok := reqBody["max_completion_tokens"]; ok {
		delete(reqBody, "max_completion_tokens")
		result.Modified = true
	}
	if normalizeCodexTools(reqBody) {
		result.Modified = true
	}
	if v, ok := reqBody["prompt_cache_key"].(string); ok {
		result.PromptCacheKey = strings.TrimSpace(v)
	}
	instructions := strings.TrimSpace(getOpenCodeCodexHeader())
	existingInstructions, _ := reqBody["instructions"].(string)
	existingInstructions = strings.TrimSpace(existingInstructions)
	if instructions != "" {
		if existingInstructions != instructions {
			reqBody["instructions"] = instructions
			result.Modified = true
		}
	} else if existingInstructions == "" {
		// When the opencode instructions are unavailable, fall back to the
		// bundled Codex CLI instructions.
		codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
		if codexInstructions != "" {
			reqBody["instructions"] = codexInstructions
			result.Modified = true
		}
	}
	// In continuation scenarios, keep item_reference and id so call_id
	// context is not lost.
	if input, ok := reqBody["input"].([]any); ok {
		input = filterCodexInput(input, needsToolContinuation)
		reqBody["input"] = input
		// NOTE(review): Modified is set whenever input exists, even if the
		// filter changed nothing — presumably intentional; confirm.
		result.Modified = true
	}
	return result
}
// normalizeCodexModel resolves an incoming model name to a canonical Codex
// model identifier, defaulting to "gpt-5.1" when nothing more specific
// matches. Provider prefixes ("openai/…") are stripped before matching.
func normalizeCodexModel(model string) string {
	if model == "" {
		return "gpt-5.1"
	}
	// Keep only the segment after the last "/" (provider-prefixed names).
	modelID := model
	if idx := strings.LastIndex(modelID, "/"); idx >= 0 {
		modelID = modelID[idx+1:]
	}
	// Table lookup (exact, then case-insensitive) takes precedence.
	if mapped := getNormalizedCodexModel(modelID); mapped != "" {
		return mapped
	}
	// Fuzzy fallback. More specific families must be checked before the
	// shorter prefixes they contain, so the case order below matters.
	lower := strings.ToLower(modelID)
	has := func(sub string) bool { return strings.Contains(lower, sub) }
	switch {
	case has("gpt-5.2-codex") || has("gpt 5.2 codex"):
		return "gpt-5.2-codex"
	case has("gpt-5.2") || has("gpt 5.2"):
		return "gpt-5.2"
	case has("gpt-5.1-codex-max") || has("gpt 5.1 codex max"):
		return "gpt-5.1-codex-max"
	case has("gpt-5.1-codex-mini") || has("gpt 5.1 codex mini"):
		return "gpt-5.1-codex-mini"
	case has("codex-mini-latest") || has("gpt-5-codex-mini") || has("gpt 5 codex mini"):
		return "codex-mini-latest"
	case has("gpt-5.1-codex") || has("gpt 5.1 codex"):
		return "gpt-5.1-codex"
	case has("gpt-5.1") || has("gpt 5.1"):
		return "gpt-5.1"
	case has("codex"):
		return "gpt-5.1-codex"
	case has("gpt-5") || has("gpt 5"):
		return "gpt-5.1"
	default:
		return "gpt-5.1"
	}
}
// getNormalizedCodexModel looks modelID up in codexModelMap, first by exact
// key and then case-insensitively. It returns "" when no mapping exists.
func getNormalizedCodexModel(modelID string) string {
	if modelID == "" {
		return ""
	}
	if mapped, found := codexModelMap[modelID]; found {
		return mapped
	}
	// Case-insensitive fallback scan over the alias table.
	lowered := strings.ToLower(modelID)
	for candidate, mapped := range codexModelMap {
		if strings.ToLower(candidate) == lowered {
			return mapped
		}
	}
	return ""
}
// getOpenCodeCachedPrompt returns the prompt text at url, backed by a local
// cache under the codex cache directory. While the cache is within
// codexCacheTTL it is served directly; when stale, the URL is revalidated
// with the stored ETag (304 keeps the cached copy). On fetch failure the
// stale cached content, if any, is returned. Returns "" when nothing is
// available.
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
	cacheDir := codexCachePath("")
	if cacheDir == "" {
		// No resolvable home directory — caching (and fetching) is skipped.
		return ""
	}
	cacheFile := filepath.Join(cacheDir, cacheFileName)
	metaFile := filepath.Join(cacheDir, metaFileName)
	var cachedContent string
	if content, ok := readFile(cacheFile); ok {
		cachedContent = content
	}
	var meta opencodeCacheMetadata
	// Fresh cache hit: content exists and the last check is within TTL.
	if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
		if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
			return cachedContent
		}
	}
	// Stale or missing: revalidate with the stored ETag (empty on first run).
	content, etag, status, err := fetchWithETag(url, meta.ETag)
	if err == nil && status == http.StatusNotModified && cachedContent != "" {
		// 304: upstream unchanged; keep serving the cached copy.
		// NOTE(review): metadata LastChecked is not refreshed on 304, so the
		// next call revalidates again — confirm this is intended.
		return cachedContent
	}
	if err == nil && status >= 200 && status < 300 && content != "" {
		// Successful fetch: persist content and metadata best-effort.
		_ = writeFile(cacheFile, content)
		meta = opencodeCacheMetadata{
			ETag:        etag,
			LastFetch:   time.Now().UTC().Format(time.RFC3339),
			LastChecked: time.Now().UnixMilli(),
		}
		_ = writeJSON(metaFile, meta)
		return content
	}
	// Fetch failed or returned nothing usable: fall back to stale cache.
	return cachedContent
}
// getOpenCodeCodexHeader returns the Codex system instructions, preferring
// the cached copy fetched from the opencode repository and falling back to
// the bundled Codex CLI instructions when the cache yields nothing.
func getOpenCodeCodexHeader() string {
	// Prefer instructions cached from the opencode repository.
	if instructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json"); instructions != "" {
		return instructions
	}
	// Otherwise fall back to the locally embedded Codex CLI instructions.
	return getCodexCLIInstructions()
}
// getCodexCLIInstructions returns the Codex CLI instructions embedded at
// build time from prompts/codex_cli_instructions.md.
func getCodexCLIInstructions() string {
	return codexCLIInstructions
}
// GetOpenCodeInstructions returns the current Codex system instructions
// (cached opencode header, falling back to the embedded CLI instructions).
func GetOpenCodeInstructions() string {
	return getOpenCodeCodexHeader()
}
// GetCodexCLIInstructions returns the embedded Codex CLI instruction content.
func GetCodexCLIInstructions() string {
	return getCodexCLIInstructions()
}
// ReplaceWithCodexInstructions overwrites the request's instructions with
// the embedded Codex CLI instructions when they differ. It reports whether
// the body was modified.
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
	want := strings.TrimSpace(getCodexCLIInstructions())
	if want == "" {
		return false
	}
	current, _ := reqBody["instructions"].(string)
	if strings.TrimSpace(current) == want {
		// Already in the desired form; leave the body untouched.
		return false
	}
	reqBody["instructions"] = want
	return true
}
// IsInstructionError reports whether an upstream error message relates to
// instruction formatting or system-prompt problems (case-insensitive
// keyword match).
func IsInstructionError(errorMessage string) bool {
	if errorMessage == "" {
		return false
	}
	msg := strings.ToLower(errorMessage)
	for _, keyword := range []string{
		"instruction",
		"instructions",
		"system prompt",
		"system message",
		"invalid prompt",
		"prompt format",
	} {
		if strings.Contains(msg, keyword) {
			return true
		}
	}
	return false
}
// filterCodexInput filters item_reference entries and id/call_id fields as
// needed. When preserveReferences is true, references and ids are kept so
// continuation requests retain their call_id context; when false,
// item_reference entries are dropped and id/call_id fields are stripped
// (call_id is kept on tool-call items).
//
// Items are shallow-copied only when a field must change; otherwise the
// original maps are passed through, so callers must not rely on isolation.
func filterCodexInput(input []any, preserveReferences bool) []any {
	filtered := make([]any, 0, len(input))
	for _, item := range input {
		m, ok := item.(map[string]any)
		if !ok {
			// Non-map items pass through untouched.
			filtered = append(filtered, item)
			continue
		}
		typ, _ := m["type"].(string)
		if typ == "item_reference" {
			if !preserveReferences {
				// Outside continuation, references are dropped entirely.
				continue
			}
			// Shallow-copy the reference so the caller's map is not aliased.
			newItem := make(map[string]any, len(m))
			for key, value := range m {
				newItem[key] = value
			}
			filtered = append(filtered, newItem)
			continue
		}
		newItem := m
		copied := false
		// Create a copy only when a field must change, to avoid mutating the
		// original input in place.
		ensureCopy := func() {
			if copied {
				return
			}
			newItem = make(map[string]any, len(m))
			for key, value := range m {
				newItem[key] = value
			}
			copied = true
		}
		if isCodexToolCallItemType(typ) {
			// Backfill call_id from id when missing so tool outputs stay
			// linked to their originating calls.
			if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
				if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
					ensureCopy()
					newItem["call_id"] = id
				}
			}
		}
		if !preserveReferences {
			// Strip ids (and call_id on non-tool items) for fresh requests.
			ensureCopy()
			delete(newItem, "id")
			if !isCodexToolCallItemType(typ) {
				delete(newItem, "call_id")
			}
		}
		filtered = append(filtered, newItem)
	}
	return filtered
}
// isCodexToolCallItemType reports whether an input item type denotes a tool
// call or a tool-call output (e.g. "function_call", "function_call_output").
func isCodexToolCallItemType(typ string) bool {
	switch {
	case typ == "":
		return false
	case strings.HasSuffix(typ, "_call"):
		return true
	default:
		return strings.HasSuffix(typ, "_call_output")
	}
}
// normalizeCodexTools flattens OpenAI-style nested tool definitions
// ({"type":"function","function":{...}}) by hoisting name, description,
// parameters and strict to the top level when absent there. Entries are
// mutated in place; the function reports whether anything changed.
func normalizeCodexTools(reqBody map[string]any) bool {
	raw, exists := reqBody["tools"]
	if !exists || raw == nil {
		return false
	}
	tools, valid := raw.([]any)
	if !valid {
		return false
	}
	changed := false
	for i := range tools {
		entry, valid := tools[i].(map[string]any)
		if !valid {
			continue
		}
		kind, _ := entry["type"].(string)
		if strings.TrimSpace(kind) != "function" {
			continue
		}
		fn, valid := entry["function"].(map[string]any)
		if !valid {
			continue
		}
		// String fields are hoisted only when non-empty after trimming.
		for _, field := range []string{"name", "description"} {
			if _, present := entry[field]; present {
				continue
			}
			if s, isStr := fn[field].(string); isStr && strings.TrimSpace(s) != "" {
				entry[field] = s
				changed = true
			}
		}
		// These fields are copied verbatim whenever present in the nested map.
		for _, field := range []string{"parameters", "strict"} {
			if _, present := entry[field]; present {
				continue
			}
			if v, present := fn[field]; present {
				entry[field] = v
				changed = true
			}
		}
	}
	if changed {
		reqBody["tools"] = tools
	}
	return changed
}
func codexCachePath(filename string) string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
cacheDir := filepath.Join(home, ".opencode", "cache")
if filename == "" {
return cacheDir
}
return filepath.Join(cacheDir, filename)
}
func readFile(path string) (string, bool) {
if path == "" {
return "", false
}
data, err := os.ReadFile(path)
if err != nil {
return "", false
}
return string(data), true
}
func writeFile(path, content string) error {
if path == "" {
return fmt.Errorf("empty cache path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return os.WriteFile(path, []byte(content), 0o644)
}
func loadJSON(path string, target any) bool {
data, err := os.ReadFile(path)
if err != nil {
return false
}
if err := json.Unmarshal(data, target); err != nil {
return false
}
return true
}
func writeJSON(path string, value any) error {
if path == "" {
return fmt.Errorf("empty json path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.Marshal(value)
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}
func fetchWithETag(url, etag string) (string, string, int, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", "", 0, err
}
req.Header.Set("User-Agent", "sub2api-codex")
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", 0, err
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", resp.StatusCode, err
}
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}
package service
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestApplyCodexOAuthTransform_ToolContinuationPreservesInput covers the
// tool-continuation path: item_reference entries and ids are preserved, and
// store is no longer forced to true.
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
	// Continuation scenario: keep item_reference and id without forcing store=true.
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.2",
		"input": []any{
			map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
			map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	// store defaults to false when not explicitly set.
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 2)
	// Assert input[0] is a map so a type mismatch fails the test cleanly.
	first, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "item_reference", first["type"])
	require.Equal(t, "ref1", first["id"])
	// Assert input[1] is a map before checking its fields.
	second, ok := input[1].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "o1", second["id"])
}
// TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved verifies that an
// explicit store=false is left untouched in the continuation scenario.
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
	// Continuation scenario: an explicit store=false is no longer forced to true.
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": false,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
// TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse verifies that an
// explicit store=true is forcibly overridden to false.
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
	// An explicit store=true must also end up as false.
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": true,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
// TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs
// verifies the non-continuation path: store defaults to false and input item
// ids are removed.
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
	// Non-continuation scenario: unset store defaults to false; input ids are stripped.
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{
			map[string]any{"type": "text", "id": "t1", "text": "hi"},
		},
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	// Assert input[0] is a map so the field check below is reliable.
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	_, hasID := item["id"]
	require.False(t, hasID)
}
// TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved verifies that
// item_reference entries are dropped and ids stripped when references are
// not preserved.
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
	input := []any{
		map[string]any{"type": "item_reference", "id": "ref1"},
		map[string]any{"type": "text", "id": "t1", "text": "hi"},
	}
	filtered := filterCodexInput(input, false)
	require.Len(t, filtered, 1)
	// Assert filtered[0] is a map so the field checks are reliable.
	item, ok := filtered[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", item["type"])
	_, hasID := item["id"]
	require.False(t, hasID)
}
// TestApplyCodexOAuthTransform_EmptyInput verifies that an empty input slice
// stays empty and causes no panic.
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
	// Empty input must remain empty without triggering errors.
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{},
	}
	applyCodexOAuthTransform(reqBody)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 0)
}
// setupCodexCache points HOME at a temp dir pre-populated with a fresh codex
// header cache so tests never hit the network to fetch instructions.
// NOTE(review): overriding HOME assumes a Unix-like os.UserHomeDir; on
// Windows USERPROFILE would be consulted instead — confirm if tests run there.
func setupCodexCache(t *testing.T) {
	t.Helper()
	// Use a temporary HOME to avoid a network fetch of the header.
	tempDir := t.TempDir()
	t.Setenv("HOME", tempDir)
	cacheDir := filepath.Join(tempDir, ".opencode", "cache")
	require.NoError(t, os.MkdirAll(cacheDir, 0o755))
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
	// Fresh metadata keeps the cache within TTL for the duration of the test.
	meta := map[string]any{
		"etag":        "",
		"lastFetch":   time.Now().UTC().Format(time.RFC3339),
		"lastChecked": time.Now().UnixMilli(),
	}
	data, err := json.Marshal(meta)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}
......@@ -20,6 +20,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
......@@ -41,6 +42,7 @@ var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
"conversation_id": true,
"user-agent": true,
"originator": true,
"session_id": true,
......@@ -84,12 +86,15 @@ type OpenAIGatewayService struct {
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
......@@ -100,12 +105,14 @@ func NewOpenAIGatewayService(
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
return &OpenAIGatewayService{
accountRepo: accountRepo,
......@@ -114,12 +121,15 @@ func NewOpenAIGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
}
}
......@@ -158,7 +168,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
......@@ -169,16 +179,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
}
// 2. Get schedulable OpenAI accounts
var accounts []Account
var err error
// 简易模式:忽略分组限制,查询所有可用账号
if s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
}
accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
......@@ -190,6 +191,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
......@@ -300,7 +306,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
......@@ -336,6 +342,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if isExcluded(acc.ID) {
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
......@@ -445,6 +457,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
if s.schedulerSnapshot != nil {
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
return accounts, err
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
......@@ -467,6 +483,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
// getSchedulableAccount loads an account by ID, preferring the scheduler
// snapshot when it is wired and falling back to the repository otherwise.
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
	if snap := s.schedulerSnapshot; snap != nil {
		return snap.GetAccount(ctx, accountID)
	}
	return s.accountRepo.GetByID(ctx, accountID)
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
......@@ -485,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth:
// 使用 TokenProvider 获取缓存的 token
if s.openAITokenProvider != nil {
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 降级:TokenProvider 未配置时直接从账号读取
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
......@@ -511,7 +543,7 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
}
// handleFailoverSideEffects drains the upstream error body (bounded) and
// passes status, headers, and body to the rate-limit service so side effects
// such as marking the account rate-limited can be applied.
//
// The flattened merge view contained both the old unbounded io.ReadAll line
// and the new bounded one — a duplicate declaration of `body`; the bounded
// read is kept.
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
	// Cap the read at 2 MiB so a misbehaving upstream cannot exhaust memory.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
......@@ -528,30 +560,94 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract model and stream from parsed body
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
promptCacheKey := ""
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
}
// Track if body needs re-serialization
bodyModified := false
originalModel := reqModel
// Apply model mapping
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
reqBody["model"] = mappedModel
bodyModified = true
}
// For OAuth accounts using ChatGPT internal API:
// 1. Add store: false
// 2. Normalize input format for Codex API compatibility
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
bodyModified = true
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok {
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" && normalizedModel != model {
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, normalizedModel, account.Name, account.Type, isCodexCLI)
reqBody["model"] = normalizedModel
mappedModel = normalizedModel
bodyModified = true
}
}
// Normalize input format: convert AI SDK multi-part content format to simplified format
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
// Codex API expects: {"content": "..."}
if normalizeInputForCodexAPI(reqBody) {
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
reasoning["effort"] = "none"
bodyModified = true
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
}
}
if account.Type == AccountTypeOAuth && !isCodexCLI {
codexResult := applyCodexOAuthTransform(reqBody)
if codexResult.Modified {
bodyModified = true
}
if codexResult.NormalizedModel != "" {
mappedModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
}
}
// Handle max_output_tokens based on platform and account type
if !isCodexCLI {
if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens {
switch account.Platform {
case PlatformOpenAI:
// For OpenAI API Key, remove max_output_tokens (not supported)
// For OpenAI OAuth (Responses API), keep it (supported)
if account.Type == AccountTypeAPIKey {
delete(reqBody, "max_output_tokens")
bodyModified = true
}
case PlatformAnthropic:
// For Anthropic (Claude), convert to max_tokens
delete(reqBody, "max_output_tokens")
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
reqBody["max_tokens"] = maxOutputTokens
}
bodyModified = true
case PlatformGemini:
// For Gemini, remove (will be handled by Gemini-specific transform)
delete(reqBody, "max_output_tokens")
bodyModified = true
default:
// For unknown platforms, remove to be safe
delete(reqBody, "max_output_tokens")
bodyModified = true
}
}
// Also handle max_completion_tokens (similar logic)
if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens {
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
delete(reqBody, "max_completion_tokens")
bodyModified = true
}
}
}
......@@ -571,7 +667,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
if err != nil {
return nil, err
}
......@@ -582,16 +678,63 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL = account.Proxy.URL()
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
// Send request
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// Handle error response
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
......@@ -632,7 +775,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
switch account.Type {
......@@ -672,12 +815,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
// Set accept header based on stream mode
if isStream {
req.Header.Set("accept", "text/event-stream")
} else {
req.Header.Set("accept", "application/json")
}
}
// Whitelist passthrough headers
......@@ -689,6 +826,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
}
}
if account.Type == AccountTypeOAuth {
req.Header.Set("OpenAI-Beta", "responses=experimental")
if isCodexCLI {
req.Header.Set("originator", "codex_cli_rs")
} else {
req.Header.Set("originator", "opencode")
}
req.Header.Set("accept", "text/event-stream")
if promptCacheKey != "" {
req.Header.Set("conversation_id", promptCacheKey)
req.Header.Set("session_id", promptCacheKey)
}
}
// Apply custom User-Agent if configured
customUA := account.GetOpenAIUserAgent()
......@@ -705,17 +855,53 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// Check custom error codes
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream gateway error",
},
})
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
}
// Handle upstream error (mark account status)
......@@ -723,6 +909,20 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
kind := "http_error"
if shouldDisable {
kind = "failover"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
Message: upstreamMsg,
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
......@@ -761,7 +961,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
},
})
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
......@@ -905,6 +1108,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
line = "data: " + correctedData
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
......@@ -933,6 +1141,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
......@@ -988,6 +1200,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line
}
// correctToolCallsInResponseBody runs the Codex tool-call corrector over a
// complete (non-streaming) response body. It returns the corrected bytes when
// the corrector reported a change, otherwise the original slice unmodified.
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
	if len(body) == 0 {
		return body
	}
	if fixed, changed := s.toolCorrector.CorrectToolCallsInSSEData(string(body)); changed {
		return []byte(fixed)
	}
	return body
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
......@@ -1016,6 +1242,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return nil, err
}
if account.Type == AccountTypeOAuth {
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
}
}
// Parse usage
var response struct {
Usage struct {
......@@ -1055,6 +1288,112 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return usage, nil
}
func isEventStreamResponse(header http.Header) bool {
contentType := strings.ToLower(header.Get("Content-Type"))
return strings.Contains(contentType, "text/event-stream")
}
// handleOAuthSSEToJSON handles the OAuth (Responses API) case where the client
// asked for a non-streaming response but upstream replied with SSE. When a
// final "response.done"/"response.completed" event is found, its payload is
// returned to the client as plain JSON; otherwise the raw SSE body is forwarded
// as-is. Usage is extracted in both cases so billing still works.
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
	bodyText := string(body)
	finalResponse, ok := extractCodexFinalResponse(bodyText)
	usage := &OpenAIUsage{}
	if ok {
		// Final event found: pull token usage out of its "usage" object.
		var response struct {
			Usage struct {
				InputTokens       int `json:"input_tokens"`
				OutputTokens      int `json:"output_tokens"`
				InputTokenDetails struct {
					CachedTokens int `json:"cached_tokens"`
				} `json:"input_tokens_details"`
			} `json:"usage"`
		}
		// Unmarshal errors are deliberately ignored: usage stays zero-valued.
		if err := json.Unmarshal(finalResponse, &response); err == nil {
			usage.InputTokens = response.Usage.InputTokens
			usage.OutputTokens = response.Usage.OutputTokens
			usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
		}
		body = finalResponse
		// Restore the model name the client originally requested.
		if originalModel != mappedModel {
			body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
		}
		// Correct tool calls in final response
		body = s.correctToolCallsInResponseBody(body)
	} else {
		// No final event: accumulate usage across all SSE data lines and
		// forward the (model-rewritten) SSE text unchanged.
		usage = s.parseSSEUsageFromBody(bodyText)
		if originalModel != mappedModel {
			bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
		}
		body = []byte(bodyText)
	}

	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
	// JSON when we extracted a final response; otherwise keep the upstream
	// content type (defaulting to text/event-stream).
	contentType := "application/json; charset=utf-8"
	if !ok {
		contentType = resp.Header.Get("Content-Type")
		if contentType == "" {
			contentType = "text/event-stream"
		}
	}
	c.Data(resp.StatusCode, contentType, body)
	return usage, nil
}
// extractCodexFinalResponse scans an SSE body line by line and returns the raw
// "response" payload of the first non-empty "response.done" or
// "response.completed" event. The second return value reports success.
func extractCodexFinalResponse(body string) ([]byte, bool) {
	for _, raw := range strings.Split(body, "\n") {
		if !openaiSSEDataRe.MatchString(raw) {
			continue
		}
		payload := openaiSSEDataRe.ReplaceAllString(raw, "")
		if payload == "" || payload == "[DONE]" {
			continue
		}
		var evt struct {
			Type     string          `json:"type"`
			Response json.RawMessage `json:"response"`
		}
		if err := json.Unmarshal([]byte(payload), &evt); err != nil {
			continue
		}
		terminal := evt.Type == "response.done" || evt.Type == "response.completed"
		if terminal && len(evt.Response) > 0 {
			return evt.Response, true
		}
	}
	return nil, false
}
// parseSSEUsageFromBody accumulates usage information from every SSE data line
// in body, skipping blanks and the [DONE] sentinel.
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
	result := &OpenAIUsage{}
	for _, raw := range strings.Split(body, "\n") {
		if !openaiSSEDataRe.MatchString(raw) {
			continue
		}
		payload := openaiSSEDataRe.ReplaceAllString(raw, "")
		if payload == "" || payload == "[DONE]" {
			continue
		}
		s.parseSSEUsage(payload, result)
	}
	return result
}
// replaceModelInSSEBody rewrites the model name on every SSE data line of body,
// leaving non-data lines untouched.
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
	segments := strings.Split(body, "\n")
	for idx, seg := range segments {
		if openaiSSEDataRe.MatchString(seg) {
			segments[idx] = s.replaceModelInSSELine(seg, fromModel, toModel)
		}
	}
	return strings.Join(segments, "\n")
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
......@@ -1094,101 +1433,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
return newBody
}
// normalizeInputForCodexAPI rewrites AI SDK style multi-part message content
// into the plain-string form the ChatGPT internal Codex API expects.
//
// AI SDK sends content as an array of typed objects:
//
//	{"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
//	{"content": "hello"}
//
// reqBody is mutated in place; the return value reports whether any message
// content was rewritten.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
	rawInput, present := reqBody["input"]
	if !present {
		return false
	}

	// A plain string input is already in the expected shape.
	if _, isStr := rawInput.(string); isStr {
		return false
	}

	messages, isList := rawInput.([]any)
	if !isList {
		return false
	}

	changed := false
	for _, entry := range messages {
		msg, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		rawContent, hasContent := msg["content"]
		if !hasContent {
			continue
		}
		// String content needs no conversion.
		if _, isStr := rawContent.(string); isStr {
			continue
		}
		parts, isParts := rawContent.([]any)
		if !isParts {
			continue
		}

		// Collect the textual pieces from the typed content parts.
		var texts []string
		for _, rawPart := range parts {
			part, okPart := rawPart.(map[string]any)
			if !okPart {
				continue
			}
			kind, _ := part["type"].(string)
			switch kind {
			case "input_text", "text":
				if txt, okText := part["text"].(string); okText {
					texts = append(texts, txt)
				}
			case "input_image", "image", "input_file", "file":
				// Image/file parts cannot survive the string conversion and are
				// dropped, matching the original behavior.
				// TODO: Consider preserving image/file data or handling it separately.
				continue
			default:
				// Unknown part types contribute their text field when present.
				if txt, okText := part["text"].(string); okText {
					texts = append(texts, txt)
				}
			}
		}

		// Only rewrite when at least one text part was found; text-free content
		// arrays (e.g. image-only) are left untouched.
		if len(texts) > 0 {
			msg["content"] = strings.Join(texts, "\n")
			changed = true
		}
	}

	return changed
}
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
......@@ -1197,6 +1441,7 @@ type OpenAIRecordUsageInput struct {
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
}
// RecordUsage records usage and deducts balance
......@@ -1242,28 +1487,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 添加 UserAgent
......@@ -1271,6 +1518,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
......
......@@ -3,6 +3,7 @@ package service
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
......@@ -15,6 +16,129 @@ import (
"github.com/gin-gonic/gin"
)
// stubOpenAIAccountRepo is a fixed-list AccountRepository for selection tests.
// Both listing methods hand back a defensive copy of the configured accounts
// regardless of the group/platform arguments.
type stubOpenAIAccountRepo struct {
	AccountRepository
	accounts []Account
}

// ListSchedulableByGroupIDAndPlatform returns a copy of the stubbed accounts.
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
	return append([]Account(nil), r.accounts...), nil
}

// ListSchedulableByPlatform returns a copy of the stubbed accounts.
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
	return append([]Account(nil), r.accounts...), nil
}
// stubConcurrencyCache is a permissive ConcurrencyCache: slots are always
// granted and every account reports zero load, so tests exercise only the
// filtering/priority logic of the selector.
type stubConcurrencyCache struct {
	ConcurrencyCache
}

// AcquireAccountSlot always succeeds, so concurrency limits never block a test.
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	return true, nil
}

// ReleaseAccountSlot is a no-op.
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	return nil
}

// GetAccountsLoadBatch reports a zero load rate for every requested account.
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	out := make(map[int64]*AccountLoadInfo, len(accounts))
	for _, acc := range accounts {
		out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
	}
	return out, nil
}
// TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable verifies that
// an account whose RateLimitResetAt is still in the future is skipped and the
// other schedulable account is chosen, on the load-aware path (concurrency
// service present).
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
	now := time.Now()
	resetAt := now.Add(10 * time.Minute) // rate limit has not expired yet
	groupID := int64(1)

	// Lower-priority account that must be filtered out by its pending rate limit.
	rateLimited := Account{
		ID:               1,
		Platform:         PlatformOpenAI,
		Type:             AccountTypeAPIKey,
		Status:           StatusActive,
		Schedulable:      true,
		Concurrency:      1,
		Priority:         0,
		RateLimitResetAt: &resetAt,
	}
	// Higher-numbered priority but fully schedulable; expected winner.
	available := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    1,
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
	if err != nil {
		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
	}
	if selection == nil || selection.Account == nil {
		t.Fatalf("expected selection with account")
	}
	if selection.Account.ID != available.ID {
		t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
	}
	// Release the concurrency slot acquired during selection, if any.
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService
// covers the same rate-limit filtering as above, but with a nil concurrency
// service so the selector takes the non-load-batch code path.
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
	now := time.Now()
	resetAt := now.Add(10 * time.Minute) // rate limit has not expired yet
	groupID := int64(1)

	// Account that must be skipped because its rate limit is still active.
	rateLimited := Account{
		ID:               1,
		Platform:         PlatformOpenAI,
		Type:             AccountTypeAPIKey,
		Status:           StatusActive,
		Schedulable:      true,
		Concurrency:      1,
		Priority:         0,
		RateLimitResetAt: &resetAt,
	}
	// Schedulable account expected to be selected.
	available := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    1,
	}

	svc := &OpenAIGatewayService{
		accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
		// concurrencyService is nil, forcing the non-load-batch selection path.
	}

	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
	if err != nil {
		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
	}
	if selection == nil || selection.Account == nil {
		t.Fatalf("expected selection with account")
	}
	if selection.Account.ID != available.ID {
		t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
	}
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
......@@ -220,7 +344,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
Credentials: map[string]any{"base_url": "://invalid-url"},
}
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
if err == nil {
t.Fatalf("expected error for invalid base_url when allowlist disabled")
}
......
package service
import (
"strings"
"testing"
)
// TestOpenAIGatewayService_ToolCorrection exercises the tool-correction
// integration in OpenAIGatewayService: known Codex tool names are rewritten to
// their local equivalents, and already-correct names pass through byte-for-byte.
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
	// Build a minimal service instance to test tool correction.
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}

	tests := []struct {
		name     string
		input    []byte
		expected string // substring the corrected body must contain
		changed  bool   // whether the corrector is expected to modify the body
	}{
		{
			name: "correct apply_patch in response body",
			input: []byte(`{
			"choices": [{
				"message": {
					"tool_calls": [{
						"function": {"name": "apply_patch"}
					}]
				}
			}]
		}`),
			expected: "edit",
			changed:  true,
		},
		{
			name: "correct update_plan in response body",
			input: []byte(`{
			"tool_calls": [{
				"function": {"name": "update_plan"}
			}]
		}`),
			expected: "todowrite",
			changed:  true,
		},
		{
			name: "no change for correct tool name",
			input: []byte(`{
			"tool_calls": [{
				"function": {"name": "edit"}
			}]
		}`),
			expected: "edit",
			changed:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := service.correctToolCallsInResponseBody(tt.input)
			resultStr := string(result)

			// The corrected body must contain the expected tool name.
			if !strings.Contains(resultStr, tt.expected) {
				t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
			}

			// When a change is expected, the output must differ from the input.
			if tt.changed && string(result) == string(tt.input) {
				t.Error("expected result to be different from input, but they are the same")
			}

			// When no change is expected, the output must equal the input.
			if !tt.changed && string(result) != string(tt.input) {
				t.Error("expected result to be same as input, but they are different")
			}
		})
	}
}
// TestOpenAIGatewayService_ToolCorrectorInitialization verifies the tool
// corrector can be constructed and performs a basic apply_patch -> edit fix.
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}

	if service.toolCorrector == nil {
		t.Fatal("toolCorrector should not be nil")
	}

	// Smoke-test that the corrector actually rewrites a known tool name.
	data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
	corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
	if !changed {
		t.Error("expected tool call to be corrected")
	}
	if !strings.Contains(corrected, "edit") {
		t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
	}
}
// TestToolCorrectionStats verifies the corrector's per-tool correction counters
// after a known sequence of corrections.
func TestToolCorrectionStats(t *testing.T) {
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}

	// Perform a few corrections: two apply_patch fixes, one update_plan fix.
	testData := []string{
		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
		`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
	}

	for _, data := range testData {
		service.toolCorrector.CorrectToolCallsInSSEData(data)
	}

	stats := service.toolCorrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
	}
	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}
	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}
}
package service
import (
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
	// openAITokenRefreshSkew: refresh a token when it expires within this window.
	openAITokenRefreshSkew = 3 * time.Minute
	// openAITokenCacheSkew: margin subtracted from the cache TTL so cached
	// tokens drop out of the cache before they actually expire.
	openAITokenCacheSkew = 5 * time.Minute
	// openAILockWaitTime: how long to wait before re-reading the cache when
	// another worker holds the refresh lock.
	openAILockWaitTime = 200 * time.Millisecond
)

// OpenAITokenCache is the token cache interface (reuses the GeminiTokenCache definition).
type OpenAITokenCache = GeminiTokenCache
// OpenAITokenProvider manages access_tokens for OpenAI OAuth accounts,
// combining a persisted account repository, a shared token cache (with a
// distributed refresh lock), and the OAuth service that performs refreshes.
type OpenAITokenProvider struct {
	accountRepo        AccountRepository   // source of truth for account credentials
	tokenCache         OpenAITokenCache    // shared cache + refresh lock
	openAIOAuthService *OpenAIOAuthService // performs the actual token refresh
}
// NewOpenAITokenProvider wires up a token provider with its repository, cache,
// and OAuth-service dependencies.
func NewOpenAITokenProvider(
	accountRepo AccountRepository,
	tokenCache OpenAITokenCache,
	openAIOAuthService *OpenAIOAuthService,
) *OpenAITokenProvider {
	provider := new(OpenAITokenProvider)
	provider.accountRepo = accountRepo
	provider.tokenCache = tokenCache
	provider.openAIOAuthService = openAIOAuthService
	return provider
}
// GetAccessToken returns a valid access_token for an OpenAI OAuth account.
//
// Flow: (1) try the shared cache; (2) if the stored token is near expiry,
// refresh it — preferably under a distributed lock, falling back to a lockless
// "degraded" refresh on cache/lock errors, or a short wait-and-retry when
// another worker holds the lock; (3) cache the resulting token with a TTL
// derived from its expiry (a short TTL when refresh failed, to limit 401 churn).
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}

	cacheKey := OpenAITokenCacheKey(account)

	// 1. Try the cache first.
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			slog.Debug("openai_token_cache_hit", "account_id", account.ID)
			return token, nil
		} else if err != nil {
			slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
		}
	}
	slog.Debug("openai_token_cache_miss", "account_id", account.ID)

	// 2. Refresh if the token is about to expire (or has no known expiry).
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if lockErr == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()

			// After taking the lock, re-check the cache (another worker may have refreshed already).
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				return token, nil
			}

			// Load the freshest account state from the database.
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")

			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.openAIOAuthService == nil {
					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
					refreshFailed = true // cannot refresh; mark as failed
				} else {
					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
					if err != nil {
						// On refresh failure, log a warning but do not bail out
						// immediately — fall through and try the existing token.
						slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
						refreshFailed = true // refresh failed; will use short cache TTL
					} else {
						// Merge: new credentials win; keep any extra keys from the old set.
						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
						}
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if lockErr != nil {
			// Redis error prevented taking the lock: degrade to a lockless
			// refresh (only when the token really is near expiry).
			slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)

			// Stop early if the context was already cancelled.
			if ctx.Err() != nil {
				return "", ctx.Err()
			}

			// Load the freshest account state from the database.
			if p.accountRepo != nil {
				fresh, err := p.accountRepo.GetByID(ctx, account.ID)
				if err == nil && fresh != nil {
					account = fresh
				}
			}
			expiresAt = account.GetCredentialAsTime("expires_at")

			// Only perform the lockless refresh when expires_at is missing or near expiry.
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.openAIOAuthService == nil {
					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
					refreshFailed = true
				} else {
					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
					if err != nil {
						slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
						refreshFailed = true
					} else {
						// Same merge-and-persist as the locked path.
						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
						}
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else {
			// Lock held by another worker: wait briefly, then retry the cache.
			time.Sleep(openAILockWaitTime)
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
				return token, nil
			}
		}
	}

	accessToken := account.GetOpenAIAccessToken()
	if strings.TrimSpace(accessToken) == "" {
		return "", errors.New("access_token not found in credentials")
	}

	// 3. Store in the cache.
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			// Short TTL after a failed refresh, so a stale token does not sit
			// in the cache and cause prolonged 401 churn.
			ttl = time.Minute
			slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			switch {
			case until > openAITokenCacheSkew:
				ttl = until - openAITokenCacheSkew
			case until > 0:
				ttl = until
			default:
				ttl = time.Minute
			}
		}
		if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
			slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
		}
	}

	return accessToken, nil
}
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// openAITokenCacheStub implements OpenAITokenCache for testing.
// Errors and lock behavior are injectable; call counts use atomics so they
// can be read safely from the test goroutine.
type openAITokenCacheStub struct {
	mu               sync.Mutex        // guards tokens
	tokens           map[string]string // cacheKey -> token
	getErr           error             // returned by GetAccessToken when set
	setErr           error             // returned by SetAccessToken when set
	deleteErr        error             // returned by DeleteAccessToken when set
	lockAcquired     bool              // result of AcquireRefreshLock when no error/race
	lockErr          error             // returned by AcquireRefreshLock when set
	releaseLockErr   error             // returned by ReleaseRefreshLock
	getCalled        int32
	setCalled        int32
	lockCalled       int32
	unlockCalled     int32
	simulateLockRace bool // when true, the lock always appears held by another worker
}

// newOpenAITokenCacheStub returns a stub whose lock is acquirable by default.
func newOpenAITokenCacheStub() *openAITokenCacheStub {
	return &openAITokenCacheStub{
		tokens:       make(map[string]string),
		lockAcquired: true,
	}
}

// GetAccessToken returns the stored token (or getErr) and counts the call.
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	atomic.AddInt32(&s.getCalled, 1)
	if s.getErr != nil {
		return "", s.getErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[cacheKey], nil
}

// SetAccessToken stores the token (ttl is ignored) and counts the call.
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	atomic.AddInt32(&s.setCalled, 1)
	if s.setErr != nil {
		return s.setErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[cacheKey] = token
	return nil
}

// DeleteAccessToken removes the stored token, or fails with deleteErr.
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	if s.deleteErr != nil {
		return s.deleteErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens, cacheKey)
	return nil
}

// AcquireRefreshLock reports lockAcquired, simulating an error or a lost race
// when configured, and counts the call.
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	atomic.AddInt32(&s.lockCalled, 1)
	if s.lockErr != nil {
		return false, s.lockErr
	}
	if s.simulateLockRace {
		return false, nil
	}
	return s.lockAcquired, nil
}

// ReleaseRefreshLock counts the call and returns the configured error, if any.
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	atomic.AddInt32(&s.unlockCalled, 1)
	return s.releaseLockErr
}
// openAIAccountRepoStub is a minimal stub implementing only the methods used by
// OpenAITokenProvider (GetByID and Update), with injectable errors and atomic
// call counters.
type openAIAccountRepoStub struct {
	account      *Account // returned by GetByID and replaced by Update
	getErr       error    // returned by GetByID when set
	updateErr    error    // returned by Update when set
	getCalled    int32
	updateCalled int32
}

// GetByID returns the stubbed account (or getErr) regardless of id.
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
	atomic.AddInt32(&r.getCalled, 1)
	if r.getErr != nil {
		return nil, r.getErr
	}
	return r.account, nil
}

// Update stores the given account as the new stubbed state (or fails with updateErr).
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
	atomic.AddInt32(&r.updateCalled, 1)
	if r.updateErr != nil {
		return r.updateErr
	}
	r.account = account
	return nil
}
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
type openAIOAuthServiceStub struct {
tokenInfo *OpenAITokenInfo
refreshErr error
refreshCalled int32
}
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
atomic.AddInt32(&s.refreshCalled, 1)
if s.refreshErr != nil {
return nil, s.refreshErr
}
return s.tokenInfo, nil
}
// BuildAccountCredentials converts token info into the credentials map shape
// the provider persists, deriving expires_at from ExpiresIn.
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
	expiry := time.Now().Add(time.Duration(info.ExpiresIn) * time.Second)
	creds := map[string]any{
		"access_token":  info.AccessToken,
		"refresh_token": info.RefreshToken,
		"expires_at":    expiry.Format(time.RFC3339),
	}
	return creds
}
// TestOpenAITokenProvider_CacheHit: a cached token is served without touching
// credentials and without writing back to the cache.
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	acct := &Account{
		ID:       100,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	stub.tokens[OpenAITokenCacheKey(acct)] = "cached-token"
	provider := NewOpenAITokenProvider(nil, stub, nil)
	got, err := provider.GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "cached-token", got)
	require.Equal(t, int32(1), atomic.LoadInt32(&stub.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&stub.setCalled))
}
// TestOpenAITokenProvider_CacheMiss_FromCredentials: on a cache miss with a
// far-future expiry, the credential token is returned and written to the cache.
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	acct := &Account{
		ID:       101,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   time.Now().Add(time.Hour).Format(time.RFC3339),
		},
	}
	provider := NewOpenAITokenProvider(nil, stub, nil)
	got, err := provider.GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "credential-token", got)
	// The token must now be cached under the derived key.
	require.Equal(t, "credential-token", stub.tokens[OpenAITokenCacheKey(acct)])
}
// TestOpenAITokenProvider_TokenRefresh: an expiry inside the refresh skew
// triggers exactly one OAuth refresh, and the refreshed token is returned.
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	repo := &openAIAccountRepoStub{}
	oauth := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			ExpiresIn:    3600,
		},
	}
	acct := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    time.Now().Add(time.Minute).Format(time.RFC3339),
		},
	}
	repo.account = acct
	// Exercise the refresh path through the stub-backed test provider.
	provider := &testOpenAITokenProvider{
		accountRepo:  repo,
		tokenCache:   stub,
		oauthService: oauth,
	}
	got, err := provider.GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", got)
	require.Equal(t, int32(1), atomic.LoadInt32(&oauth.refreshCalled))
}
// testOpenAITokenProvider is a test version that uses the stub OAuth service
type testOpenAITokenProvider struct {
	accountRepo  *openAIAccountRepoStub  // source of fresh account rows during refresh
	tokenCache   *openAITokenCacheStub   // token cache plus refresh-lock behavior
	oauthService *openAIOAuthServiceStub // refresh backend; nil = not configured
}
// GetAccessToken mirrors the production provider's flow against the stubs:
// cache lookup, lock-guarded refresh with a post-lock double-check, and a
// TTL-bucketed cache write-back. Statement order matters; only comments added.
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}
	cacheKey := OpenAITokenCacheKey(account)
	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}
	// 2. Check if refresh needed (missing expiry counts as "needs refresh")
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
			// Check cache again after acquiring lock
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh without an OAuth service; mark as failed
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed; a short TTL is used below
					} else {
						newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
						// Preserve credential keys the refresh response did not provide.
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account)
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Wait and retry cache
			time.Sleep(10 * time.Millisecond) // Short wait for test
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}
	accessToken := account.GetOpenAIAccessToken()
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // short TTL after a failed refresh so a retry happens soon
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > openAITokenCacheSkew {
				ttl = until - openAITokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}
	return accessToken, nil
}
// TestOpenAITokenProvider_LockRaceCondition: when the refresh lock is lost
// (simulateLockRace), the provider waits and re-reads the cache; either the
// concurrently-cached token or the original credential token is acceptable.
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &openAIAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	// Simulate another worker already refreshed and cached
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := &testOpenAITokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get the token set by the "winner" or the original
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_NilAccount: a nil account is rejected up front.
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	got, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, got)
}
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
account := &Account{
ID: 104,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
account := &Account{
ID: 105,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
// TestOpenAITokenProvider_NilCache: with no cache configured and a far-future
// expiry, the credential token is still returned.
func TestOpenAITokenProvider_NilCache(t *testing.T) {
	acct := &Account{
		ID:       106,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   time.Now().Add(time.Hour).Format(time.RFC3339),
		},
	}
	got, err := NewOpenAITokenProvider(nil, nil, nil).GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", got)
}
// TestOpenAITokenProvider_CacheGetError: a failing cache read degrades
// gracefully to the credential token.
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	stub.getErr = errors.New("redis connection failed")
	acct := &Account{
		ID:       107,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   time.Now().Add(time.Hour).Format(time.RFC3339),
		},
	}
	got, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", got)
}
// TestOpenAITokenProvider_CacheSetError: a failing cache write does not
// prevent the token from being returned.
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	stub.setErr = errors.New("redis write failed")
	acct := &Account{
		ID:       108,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   time.Now().Add(time.Hour).Format(time.RFC3339),
		},
	}
	got, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", got)
}
// TestOpenAITokenProvider_MissingAccessToken: credentials without an
// access_token yield an explicit error.
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	acct := &Account{
		ID:       109,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			// deliberately no access_token
			"expires_at": time.Now().Add(time.Hour).Format(time.RFC3339),
		},
	}
	got, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, got)
}
// TestOpenAITokenProvider_RefreshError: when the OAuth refresh fails, the
// provider falls back to the existing (still usable) token.
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	repo := &openAIAccountRepoStub{}
	oauth := &openAIOAuthServiceStub{refreshErr: errors.New("oauth refresh failed")}
	acct := &Account{
		ID:       110,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    time.Now().Add(time.Minute).Format(time.RFC3339),
		},
	}
	repo.account = acct
	provider := &testOpenAITokenProvider{
		accountRepo:  repo,
		tokenCache:   stub,
		oauthService: oauth,
	}
	got, err := provider.GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "old-token", got) // fallback to the existing token
}
// TestOpenAITokenProvider_OAuthServiceNotConfigured: with no OAuth service the
// provider cannot refresh but still serves the existing token.
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	repo := &openAIAccountRepoStub{}
	acct := &Account{
		ID:       111,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   time.Now().Add(time.Minute).Format(time.RFC3339),
		},
	}
	repo.account = acct
	provider := &testOpenAITokenProvider{
		accountRepo:  repo,
		tokenCache:   stub,
		oauthService: nil, // not configured
	}
	got, err := provider.GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "old-token", got) // fallback to the existing token
}
// TestOpenAITokenProvider_TTLCalculation: whatever TTL bucket the expiry falls
// into, the token must end up written to the cache.
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
	cases := []struct {
		name      string
		expiresIn time.Duration
	}{
		{name: "far_future_expiry", expiresIn: 1 * time.Hour},
		{name: "medium_expiry", expiresIn: 10 * time.Minute},
		{name: "near_expiry", expiresIn: 6 * time.Minute},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := newOpenAITokenCacheStub()
			acct := &Account{
				ID:       200,
				Platform: PlatformOpenAI,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   time.Now().Add(tc.expiresIn).Format(time.RFC3339),
				},
			}
			_, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
			require.NoError(t, err)
			require.Equal(t, "test-token", stub.tokens[OpenAITokenCacheKey(acct)])
		})
	}
}
// TestOpenAITokenProvider_DoubleCheckAfterLock verifies the cache is re-read
// after the refresh lock is acquired: a concurrent worker may populate it, and
// either that token or the freshly refreshed one is acceptable.
// (Cleanup: removed the dead `originalGet` variable and its `_ =` suppressor.)
func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon, forcing the refresh path.
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       112,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	cacheKey := OpenAITokenCacheKey(account)
	cache.tokens[cacheKey] = "" // empty initially: the first cache lookup misses
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Another worker caches a token after a small delay (simulating the race).
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "cached-by-other"
		cache.mu.Unlock()
	}()
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get either the refreshed token or the cached one.
	require.NotEmpty(t, token)
}
// Tests for real provider - to increase coverage
// TestOpenAITokenProvider_Real_LockFailedWait: lock acquisition fails, so the
// real provider waits; a background goroutine simulates another worker caching
// a refreshed token. Either the fallback or the refreshed token is acceptable.
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       200,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	// Set token in cache after lock wait period (simulate other worker refreshing)
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get either the fallback token or the refreshed one
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_Real_CacheHitAfterWait: with the lock unavailable,
// the provider's wait-and-recheck may pick up a token cached by another worker
// mid-wait; any non-empty result is acceptable given the timing.
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       201,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken: an account with no
// expires_at set still serves its credential token (refresh cannot proceed
// because the lock is never granted and no OAuth service exists).
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	stub.lockAcquired = false // keep the provider out of the refresh branch
	acct := &Account{
		ID:       202,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}
	got, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", got)
}
// TestOpenAITokenProvider_Real_WhitespaceToken: a whitespace-only cached value
// must be treated as a miss, falling back to the credential token.
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       203,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "real-token",
			"expires_at":   expiresAt,
		},
	}
	// Seed the cache via OpenAITokenCacheKey rather than a hardcoded literal so
	// the entry cannot silently miss if the key format ever changes.
	cache.tokens[OpenAITokenCacheKey(account)] = "   " // whitespace only - treated as empty
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "real-token", token) // should fall back to credentials
}
// TestOpenAITokenProvider_Real_LockError: a lock-acquisition error degrades
// gracefully to the credential token.
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	stub.lockErr = errors.New("redis lock failed")
	acct := &Account{
		ID:       204,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			// expiry within the refresh skew forces a lock attempt
			"expires_at": time.Now().Add(time.Minute).Format(time.RFC3339),
		},
	}
	got, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", got)
}
// TestOpenAITokenProvider_Real_WhitespaceCredentialToken: a whitespace-only
// credential token counts as missing.
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	acct := &Account{
		ID:       205,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": " ", // whitespace only
			"expires_at":   time.Now().Add(time.Hour).Format(time.RFC3339),
		},
	}
	got, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, got)
}
// TestOpenAITokenProvider_Real_NilCredentials: despite the name, this covers
// credentials that simply lack an access_token entry; an error is required.
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
	stub := newOpenAITokenCacheStub()
	acct := &Account{
		ID:       206,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			// no access_token key at all
			"expires_at": time.Now().Add(time.Hour).Format(time.RFC3339),
		},
	}
	got, err := NewOpenAITokenProvider(nil, stub, nil).GetAccessToken(context.Background(), acct)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, got)
}
package service
import "strings"
// NeedsToolContinuation reports whether the request needs tool-call
// continuation handling. Any single signal suffices: previous_response_id,
// a function_call_output/item_reference item in input, or explicitly declared
// tools/tool_choice.
func NeedsToolContinuation(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	return hasNonEmptyString(reqBody["previous_response_id"]) ||
		hasToolsSignal(reqBody) ||
		hasToolChoiceSignal(reqBody) ||
		inputHasType(reqBody, "function_call_output") ||
		inputHasType(reqBody, "item_reference")
}
// HasFunctionCallOutput reports whether input carries a function_call_output
// item, which triggers continuation validation.
func HasFunctionCallOutput(reqBody map[string]any) bool {
	return reqBody != nil && inputHasType(reqBody, "function_call_output")
}
// HasToolCallContext reports whether input contains a tool_call/function_call
// item carrying a non-blank call_id, i.e. context that a function_call_output
// can be associated with.
func HasToolCallContext(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		kind, _ := entry["type"].(string)
		if kind != "tool_call" && kind != "function_call" {
			continue
		}
		callID, _ := entry["call_id"].(string)
		if strings.TrimSpace(callID) != "" {
			return true
		}
	}
	return false
}
// FunctionCallOutputCallIDs extracts the distinct call_id values of
// function_call_output items in input, preserving first-seen order so the
// result is deterministic (the previous map-iteration order was random).
// Blank call_ids are skipped; nil is returned when none are found. Used to
// match against item_reference.id values.
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
	if reqBody == nil {
		return nil
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return nil
	}
	seen := make(map[string]struct{})
	var result []string
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if itemType, _ := itemMap["type"].(string); itemType != "function_call_output" {
			continue
		}
		callID, _ := itemMap["call_id"].(string)
		if strings.TrimSpace(callID) == "" {
			continue
		}
		if _, dup := seen[callID]; dup {
			continue
		}
		seen[callID] = struct{}{}
		result = append(result, callID)
	}
	return result
}
// HasFunctionCallOutputMissingCallID reports whether any function_call_output
// item in input lacks a (non-blank) call_id.
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if kind, _ := entry["type"].(string); kind != "function_call_output" {
			continue
		}
		if callID, _ := entry["call_id"].(string); strings.TrimSpace(callID) == "" {
			return true
		}
	}
	return false
}
// HasItemReferenceForCallIDs reports whether the item_reference ids present in
// input cover every call_id in callIDs. Used to validate continuation when the
// chain relies solely on reference items.
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if reqBody == nil || len(callIDs) == 0 {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	refs := make(map[string]struct{})
	for _, raw := range items {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if kind, _ := entry["type"].(string); kind != "item_reference" {
			continue
		}
		id, _ := entry["id"].(string)
		if id = strings.TrimSpace(id); id != "" {
			refs[id] = struct{}{}
		}
	}
	if len(refs) == 0 {
		return false
	}
	for _, callID := range callIDs {
		if _, covered := refs[callID]; !covered {
			return false
		}
	}
	return true
}
// inputHasType reports whether input contains an item whose "type" equals want.
func inputHasType(reqBody map[string]any, want string) bool {
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		if entry, ok := raw.(map[string]any); ok {
			if kind, _ := entry["type"].(string); kind == want {
				return true
			}
		}
	}
	return false
}
// hasNonEmptyString reports whether value is a string with non-blank content.
func hasNonEmptyString(value any) bool {
	s, ok := value.(string)
	if !ok {
		return false
	}
	return strings.TrimSpace(s) != ""
}
// hasToolsSignal reports whether "tools" is explicitly declared as a non-empty
// array; missing, nil, or non-array values do not count.
func hasToolsSignal(reqBody map[string]any) bool {
	tools, ok := reqBody["tools"].([]any)
	return ok && len(tools) > 0
}
// hasToolChoiceSignal reports whether "tool_choice" is explicitly declared:
// a non-blank string or a non-empty object. Anything else does not count.
func hasToolChoiceSignal(reqBody map[string]any) bool {
	switch value := reqBody["tool_choice"].(type) {
	case string:
		return strings.TrimSpace(value) != ""
	case map[string]any:
		return len(value) > 0
	default:
		return false
	}
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestNeedsToolContinuationSignals covers every signal source that can trigger
// continuation, ensuring the decision logic is complete.
func TestNeedsToolContinuationSignals(t *testing.T) {
	cases := []struct {
		name string
		body map[string]any
		want bool
	}{
		{name: "nil", body: nil, want: false},
		{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
		{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
		{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
		{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
		{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
		{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
		{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
		{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
		{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
		{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
		{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
		})
	}
}
// TestHasFunctionCallOutput: only a function_call_output item in a valid input
// slice counts as continuation output.
func TestHasFunctionCallOutput(t *testing.T) {
	require.False(t, HasFunctionCallOutput(nil))
	withOutput := map[string]any{
		"input": []any{map[string]any{"type": "function_call_output"}},
	}
	require.True(t, HasFunctionCallOutput(withOutput))
	nonSlice := map[string]any{"input": "text"}
	require.False(t, HasFunctionCallOutput(nonSlice))
}
// TestHasToolCallContext: tool_call/function_call items must carry a call_id
// to qualify as associable context.
func TestHasToolCallContext(t *testing.T) {
	require.False(t, HasToolCallContext(nil))
	toolCall := map[string]any{
		"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
	}
	require.True(t, HasToolCallContext(toolCall))
	functionCall := map[string]any{
		"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
	}
	require.True(t, HasToolCallContext(functionCall))
	noID := map[string]any{
		"input": []any{map[string]any{"type": "tool_call"}},
	}
	require.False(t, HasToolCallContext(noID))
}
// TestFunctionCallOutputCallIDs: only non-empty call_ids are extracted, and
// duplicates are collapsed.
func TestFunctionCallOutputCallIDs(t *testing.T) {
	require.Empty(t, FunctionCallOutputCallIDs(nil))
	body := map[string]any{
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
			map[string]any{"type": "function_call_output", "call_id": ""},
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
	}
	require.ElementsMatch(t, []string{"call_1"}, FunctionCallOutputCallIDs(body))
}
// TestHasFunctionCallOutputMissingCallID: a function_call_output without a
// call_id is flagged; one with a call_id is not.
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
	require.False(t, HasFunctionCallOutputMissingCallID(nil))
	missing := map[string]any{
		"input": []any{map[string]any{"type": "function_call_output"}},
	}
	require.True(t, HasFunctionCallOutputMissingCallID(missing))
	present := map[string]any{
		"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
	}
	require.False(t, HasFunctionCallOutputMissingCallID(present))
}
// TestHasItemReferenceForCallIDs: item_reference ids must cover every call_id
// for the context to be considered associable.
func TestHasItemReferenceForCallIDs(t *testing.T) {
	require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
	require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
	body := map[string]any{
		"input": []any{
			map[string]any{"type": "item_reference", "id": "call_1"},
			map[string]any{"type": "item_reference", "id": "call_2"},
		},
	}
	require.True(t, HasItemReferenceForCallIDs(body, []string{"call_1"}))
	require.True(t, HasItemReferenceForCallIDs(body, []string{"call_1", "call_2"}))
	require.False(t, HasItemReferenceForCallIDs(body, []string{"call_1", "call_3"}))
}
package service
import (
"encoding/json"
"fmt"
"log"
"sync"
)
// codexToolNameMapping maps Codex-native tool names (both snake_case and
// camelCase spellings) to the equivalent OpenCode tool names.
var codexToolNameMapping = map[string]string{
	"apply_patch":  "edit",
	"applyPatch":   "edit",
	"update_plan":  "todowrite",
	"updatePlan":   "todowrite",
	"read_plan":    "todoread",
	"readPlan":     "todoread",
	"search_files": "grep",
	"searchFiles":  "grep",
	"list_files":   "glob",
	"listFiles":    "glob",
	"read_file":    "read",
	"readFile":     "read",
	"write_file":   "write",
	"writeFile":    "write",
	"execute_bash": "bash",
	"executeBash":  "bash",
	"exec_bash":    "bash",
	"execBash":     "bash",
}
// ToolCorrectionStats records tool-correction statistics (exported for JSON serialization).
type ToolCorrectionStats struct {
	TotalCorrected    int            `json:"total_corrected"`     // total number of corrected tool calls
	CorrectionsByTool map[string]int `json:"corrections_by_tool"` // count per "from->to" rename
}
// CodexToolCorrector rewrites Codex tool-call names and arguments to the
// OpenCode equivalents in streamed (SSE) response payloads.
type CodexToolCorrector struct {
	stats ToolCorrectionStats // correction counters; guarded by mu
	mu    sync.RWMutex
}
// NewCodexToolCorrector returns a corrector with an initialized stats map.
func NewCodexToolCorrector() *CodexToolCorrector {
	c := &CodexToolCorrector{}
	c.stats.CorrectionsByTool = make(map[string]int)
	return c
}
// CorrectToolCallsInSSEData corrects tool calls found anywhere in an SSE JSON
// payload: at the top level, under "delta", and under each choices[].message /
// choices[].delta. Non-JSON data passes through untouched. It returns the
// (possibly re-serialized) data and whether any correction was applied.
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
	if data == "" || data == "\n" {
		return data, false
	}
	var payload map[string]any
	if err := json.Unmarshal([]byte(data), &payload); err != nil {
		// Not valid JSON; return the original data unchanged.
		return data, false
	}
	// The same tool_calls/function_call pair can appear at several nesting
	// levels; scan each container with one shared helper instead of repeating
	// the same two type assertions five times.
	corrected := c.correctToolCarrier(payload)
	if delta, ok := payload["delta"].(map[string]any); ok {
		if c.correctToolCarrier(delta) {
			corrected = true
		}
	}
	if choices, ok := payload["choices"].([]any); ok {
		for _, choice := range choices {
			choiceMap, ok := choice.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := choiceMap["message"].(map[string]any); ok {
				if c.correctToolCarrier(message) {
					corrected = true
				}
			}
			if delta, ok := choiceMap["delta"].(map[string]any); ok {
				if c.correctToolCarrier(delta) {
					corrected = true
				}
			}
		}
	}
	if !corrected {
		return data, false
	}
	// Serialize the corrected payload back to JSON.
	correctedBytes, err := json.Marshal(payload)
	if err != nil {
		log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
		return data, false
	}
	return string(correctedBytes), true
}

// correctToolCarrier corrects the "tool_calls" array and "function_call"
// object of a single container map; returns true when either changed.
func (c *CodexToolCorrector) correctToolCarrier(container map[string]any) bool {
	corrected := false
	if toolCalls, ok := container["tool_calls"].([]any); ok {
		if c.correctToolCallsArray(toolCalls) {
			corrected = true
		}
	}
	if functionCall, ok := container["function_call"].(map[string]any); ok {
		if c.correctFunctionCall(functionCall) {
			corrected = true
		}
	}
	return corrected
}
// correctToolCallsArray applies name/parameter corrections to each entry of a
// tool_calls array; returns true if any entry was changed.
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
	corrected := false
	for _, raw := range toolCalls {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		fn, ok := entry["function"].(map[string]any)
		if !ok {
			continue
		}
		if c.correctFunctionCall(fn) {
			corrected = true
		}
	}
	return corrected
}
// correctFunctionCall fixes a single function call's tool name (via the
// mapping table) and then its parameters; returns true if anything changed.
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
	name, _ := functionCall["name"].(string)
	if name == "" {
		return false
	}
	corrected := false
	if mapped, found := codexToolNameMapping[name]; found {
		functionCall["name"] = mapped
		c.recordCorrection(name, mapped)
		corrected = true
		// Parameter rules below key off the corrected name.
		name = mapped
	}
	if c.correctToolParameters(name, functionCall) {
		corrected = true
	}
	return corrected
}
// correctToolParameters normalizes tool arguments to the OpenCode schema.
// arguments may arrive as a JSON string or an already-decoded map; when a
// string was supplied, the corrected map is re-serialized to a string so the
// payload shape is preserved. Returns true when any parameter was changed.
// (Cleanup: the duplicated workdir/work_dir branches are folded into one loop;
// the emitted log lines are unchanged.)
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
	arguments, ok := functionCall["arguments"]
	if !ok {
		return false
	}
	var argsMap map[string]any
	switch v := arguments.(type) {
	case string:
		// Parse the JSON string form.
		if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
			return false
		}
	case map[string]any:
		argsMap = v
	default:
		return false
	}
	corrected := false
	switch toolName {
	case "bash":
		// OpenCode's bash tool does not accept a working-directory parameter
		// under either spelling; drop both.
		for _, key := range []string{"workdir", "work_dir"} {
			if _, exists := argsMap[key]; exists {
				delete(argsMap, key)
				corrected = true
				log.Printf("[CodexToolCorrector] Removed '%s' parameter from bash tool", key)
			}
		}
	case "edit":
		// OpenCode's edit tool expects file_path; map a Codex-style "path" key
		// only when file_path is not already present.
		if _, exists := argsMap["file_path"]; !exists {
			if path, exists := argsMap["path"]; exists {
				argsMap["file_path"] = path
				delete(argsMap, "path")
				corrected = true
				log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
			}
		}
	}
	// Write the corrected arguments back in their original representation.
	if corrected {
		if _, wasString := arguments.(string); wasString {
			if newArgsJSON, err := json.Marshal(argsMap); err == nil {
				functionCall["arguments"] = string(newArgsJSON)
			}
		} else {
			functionCall["arguments"] = argsMap
		}
	}
	return corrected
}
// recordCorrection bumps the correction counters for one name rewrite and
// logs the rewrite. Safe for concurrent use.
func (c *CodexToolCorrector) recordCorrection(from, to string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.stats.TotalCorrected++
	key := fmt.Sprintf("%s->%s", from, to)
	c.stats.CorrectionsByTool[key]++

	log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
		from, to, c.stats.TotalCorrected)
}
// GetStats returns a snapshot of the correction statistics. The per-tool map
// is copied so callers cannot race with future corrections.
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
	c.mu.RLock()
	defer c.mu.RUnlock()

	byTool := make(map[string]int, len(c.stats.CorrectionsByTool))
	for key, count := range c.stats.CorrectionsByTool {
		byTool[key] = count
	}
	return ToolCorrectionStats{
		TotalCorrected:    c.stats.TotalCorrected,
		CorrectionsByTool: byTool,
	}
}
// ResetStats zeroes all correction counters.
func (c *CodexToolCorrector) ResetStats() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.stats.TotalCorrected = 0
	c.stats.CorrectionsByTool = map[string]int{}
}
// CorrectToolName maps a Codex tool name to its OpenCode equivalent (for
// non-SSE call sites). The second result reports whether a remapping applied.
func CorrectToolName(name string) (string, bool) {
	canonical, found := codexToolNameMapping[name]
	if !found {
		return name, false
	}
	return canonical, true
}
// GetToolNameMapping returns a defensive copy of the Codex->OpenCode tool
// name map so callers cannot mutate the package-level table.
func GetToolNameMapping() map[string]string {
	out := make(map[string]string, len(codexToolNameMapping))
	for from, to := range codexToolNameMapping {
		out[from] = to
	}
	return out
}
package service
import (
"encoding/json"
"testing"
)
// TestCorrectToolCallsInSSEData exercises the SSE correction entry point over
// every payload shape the corrector understands — top-level tool_calls,
// function_call, delta.tool_calls, and choices[].message.tool_calls — plus
// malformed, empty, and already-correct inputs.
func TestCorrectToolCallsInSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()
	tests := []struct {
		name            string
		input           string
		expectCorrected bool
		checkFunc       func(t *testing.T, result string)
	}{
		// Inputs that cannot or need not be corrected pass through unchanged.
		{
			name:            "empty string",
			input:           "",
			expectCorrected: false,
		},
		{
			name:            "newline only",
			input:           "\n",
			expectCorrected: false,
		},
		{
			name:            "invalid json",
			input:           "not a json",
			expectCorrected: false,
		},
		{
			name:            "correct apply_patch in tool_calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in result")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "correct update_plan in function_call",
			input:           `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				functionCall, ok := payload["function_call"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function_call format")
				}
				if functionCall["name"] != "todowrite" {
					t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "correct search_files in delta.tool_calls",
			input:           `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				delta, ok := payload["delta"].(map[string]any)
				if !ok {
					t.Fatal("Invalid delta format")
				}
				toolCalls, ok := delta["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in delta")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "grep" {
					t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "correct list_files in choices.message.tool_calls",
			input:           `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				choices, ok := payload["choices"].([]any)
				if !ok || len(choices) == 0 {
					t.Fatal("No choices found in result")
				}
				choice, ok := choices[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid choice format")
				}
				message, ok := choice["message"].(map[string]any)
				if !ok {
					t.Fatal("Invalid message format")
				}
				toolCalls, ok := message["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in message")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "glob" {
					t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "no correction needed",
			input:           `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
			expectCorrected: false,
		},
		// Every entry in a multi-call payload must be corrected independently.
		{
			name:            "correct multiple tool calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) < 2 {
					t.Fatal("Expected at least 2 tool_calls")
				}
				toolCall1, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid first tool_call format")
				}
				func1, ok := toolCall1["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid first function format")
				}
				if func1["name"] != "edit" {
					t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
				}
				toolCall2, ok := toolCalls[1].(map[string]any)
				if !ok {
					t.Fatal("Invalid second tool_call format")
				}
				func2, ok := toolCall2["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid second function format")
				}
				if func2["name"] != "read" {
					t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
				}
			},
		},
		// camelCase aliases must also be recognized.
		{
			name:            "camelCase format - applyPatch",
			input:           `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in result")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
			if corrected != tt.expectCorrected {
				t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
			}
			// Uncorrected data must be returned verbatim.
			if !corrected && result != tt.input {
				t.Errorf("Expected unchanged result when not corrected")
			}
			if tt.checkFunc != nil {
				tt.checkFunc(t, result)
			}
		})
	}
}
// TestCorrectToolName verifies the direct name-mapping helper for every known
// snake_case and camelCase Codex alias, and that unknown or already-canonical
// names pass through unchanged.
func TestCorrectToolName(t *testing.T) {
	tests := []struct {
		input     string
		expected  string
		corrected bool
	}{
		{"apply_patch", "edit", true},
		{"applyPatch", "edit", true},
		{"update_plan", "todowrite", true},
		{"updatePlan", "todowrite", true},
		{"read_plan", "todoread", true},
		{"readPlan", "todoread", true},
		{"search_files", "grep", true},
		{"searchFiles", "grep", true},
		{"list_files", "glob", true},
		{"listFiles", "glob", true},
		{"read_file", "read", true},
		{"readFile", "read", true},
		{"write_file", "write", true},
		{"writeFile", "write", true},
		{"execute_bash", "bash", true},
		{"executeBash", "bash", true},
		{"exec_bash", "bash", true},
		{"execBash", "bash", true},
		// Unknown and already-canonical names are returned as-is.
		{"unknown_tool", "unknown_tool", false},
		{"read", "read", false},
		{"edit", "edit", false},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			result, corrected := CorrectToolName(tt.input)
			if corrected != tt.corrected {
				t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
			}
			if result != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, result)
			}
		})
	}
}
// TestGetToolNameMapping checks that the exported mapping contains the core
// aliases and that the returned map is a copy (mutations must not leak back
// into the package-level table).
func TestGetToolNameMapping(t *testing.T) {
	mapping := GetToolNameMapping()
	expectedMappings := map[string]string{
		"apply_patch":  "edit",
		"update_plan":  "todowrite",
		"read_plan":    "todoread",
		"search_files": "grep",
		"list_files":   "glob",
	}
	for from, to := range expectedMappings {
		if mapping[from] != to {
			t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
		}
	}
	// Mutating the returned copy must not affect a fresh copy.
	mapping["test_tool"] = "test_value"
	newMapping := GetToolNameMapping()
	if _, exists := newMapping["test_tool"]; exists {
		t.Error("Modifications to returned mapping should not affect original")
	}
}
// TestCorrectorStats verifies stats start at zero, accumulate per correction
// (keyed "from->to"), and are cleared by ResetStats.
func TestCorrectorStats(t *testing.T) {
	corrector := NewCodexToolCorrector()
	stats := corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
	}
	// Three corrections: apply_patch twice, update_plan once.
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
	stats = corrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
	}
	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}
	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}
	// Reset must clear both the total and the per-tool map.
	corrector.ResetStats()
	stats = corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
	}
}
// TestComplexSSEData runs the corrector against a realistic, fully-populated
// streaming chunk and checks that the deeply nested tool name
// (choices[0].delta.tool_calls[0].function.name) gets rewritten.
func TestComplexSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()
	input := `{
		"id": "chatcmpl-123",
		"object": "chat.completion.chunk",
		"created": 1234567890,
		"model": "gpt-5.1-codex",
		"choices": [
			{
				"index": 0,
				"delta": {
					"tool_calls": [
						{
							"index": 0,
							"function": {
								"name": "apply_patch",
								"arguments": "{\"file\":\"test.go\"}"
							}
						}
					]
				},
				"finish_reason": null
			}
		]
	}`
	result, corrected := corrector.CorrectToolCallsInSSEData(input)
	if !corrected {
		t.Error("Expected data to be corrected")
	}
	// Walk the nested structure down to the corrected function name.
	var payload map[string]any
	if err := json.Unmarshal([]byte(result), &payload); err != nil {
		t.Fatalf("Failed to parse result: %v", err)
	}
	choices, ok := payload["choices"].([]any)
	if !ok || len(choices) == 0 {
		t.Fatal("No choices found in result")
	}
	choice, ok := choices[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid choice format")
	}
	delta, ok := choice["delta"].(map[string]any)
	if !ok {
		t.Fatal("Invalid delta format")
	}
	toolCalls, ok := delta["tool_calls"].([]any)
	if !ok || len(toolCalls) == 0 {
		t.Fatal("No tool_calls found in delta")
	}
	toolCall, ok := toolCalls[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid tool_call format")
	}
	function, ok := toolCall["function"].(map[string]any)
	if !ok {
		t.Fatal("Invalid function format")
	}
	if function["name"] != "edit" {
		t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
	}
}
// TestCorrectToolParameters verifies the per-tool argument rewrites:
// bash drops the unsupported workdir parameter, and edit (reached via the
// apply_patch alias) renames path to file_path. Arguments arrive as JSON
// strings, so the corrector must also re-serialize them.
func TestCorrectToolParameters(t *testing.T) {
	corrector := NewCodexToolCorrector()
	tests := []struct {
		name     string
		input    string
		expected map[string]bool // key: parameter name, value: whether it should exist afterwards
	}{
		{
			name: "remove workdir from bash tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "bash",
						"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"command": true,
				"workdir": false,
			},
		},
		{
			name: "rename path to file_path in edit tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "apply_patch",
						"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"file_path":  true,
				"path":       false,
				"old_string": true,
				"new_string": true,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
			if !changed {
				t.Error("expected data to be corrected")
			}
			// Parse the corrected payload.
			var result map[string]any
			if err := json.Unmarshal([]byte(corrected), &result); err != nil {
				t.Fatalf("failed to parse corrected data: %v", err)
			}
			// Inspect the (single) tool call.
			toolCalls, ok := result["tool_calls"].([]any)
			if !ok || len(toolCalls) == 0 {
				t.Fatal("no tool_calls found in corrected data")
			}
			toolCall, ok := toolCalls[0].(map[string]any)
			if !ok {
				t.Fatal("invalid tool_call structure")
			}
			function, ok := toolCall["function"].(map[string]any)
			if !ok {
				t.Fatal("no function found in tool_call")
			}
			// Arguments were a JSON string on input, so they must still be one.
			argumentsStr, ok := function["arguments"].(string)
			if !ok {
				t.Fatal("arguments is not a string")
			}
			var args map[string]any
			if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
				t.Fatalf("failed to parse arguments: %v", err)
			}
			// Verify the expected presence/absence of each parameter.
			for param, shouldExist := range tt.expected {
				_, exists := args[param]
				if shouldExist && !exists {
					t.Errorf("expected parameter %q to exist, but it doesn't", param)
				}
				if !shouldExist && exists {
					t.Errorf("expected parameter %q to not exist, but it does", param)
				}
			}
		})
	}
}
package service
import (
"context"
"errors"
"time"
)
// GetAccountAvailabilityStats returns current account availability stats.
//
// Query-level filtering is intentionally limited to platform/group to match the dashboard scope.
//
// It returns four values: per-platform aggregates, per-group aggregates,
// per-account detail (all keyed by their IDs), and the snapshot timestamp.
// Availability is computed from status, schedulability, and the rate-limit /
// overload / temporary-unschedulable windows relative to "now".
func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFilter string, groupIDFilter *int64) (
	map[string]*PlatformAvailability,
	map[int64]*GroupAvailability,
	map[int64]*AccountAvailability,
	*time.Time,
	error,
) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, nil, nil, nil, err
	}
	accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
	if err != nil {
		return nil, nil, nil, nil, err
	}
	// Group filtering happens in memory: keep accounts that belong to the group.
	if groupIDFilter != nil && *groupIDFilter > 0 {
		filtered := make([]Account, 0, len(accounts))
		for _, acc := range accounts {
			for _, grp := range acc.Groups {
				if grp != nil && grp.ID == *groupIDFilter {
					filtered = append(filtered, acc)
					break
				}
			}
		}
		accounts = filtered
	}
	// A single "now" is used for every window comparison so the snapshot is consistent.
	now := time.Now()
	collectedAt := now
	platform := make(map[string]*PlatformAvailability)
	group := make(map[int64]*GroupAvailability)
	account := make(map[int64]*AccountAvailability)
	for _, acc := range accounts {
		if acc.ID <= 0 {
			continue
		}
		// Derive the per-account state flags from the time-bounded windows.
		isTempUnsched := false
		if acc.TempUnschedulableUntil != nil && now.Before(*acc.TempUnschedulableUntil) {
			isTempUnsched = true
		}
		isRateLimited := acc.RateLimitResetAt != nil && now.Before(*acc.RateLimitResetAt)
		isOverloaded := acc.OverloadUntil != nil && now.Before(*acc.OverloadUntil)
		hasError := acc.Status == StatusError
		// Normalize exclusive status flags so the UI doesn't show conflicting badges.
		if hasError {
			isRateLimited = false
			isOverloaded = false
		}
		isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
		// Accumulate per-platform counters.
		if acc.Platform != "" {
			if _, ok := platform[acc.Platform]; !ok {
				platform[acc.Platform] = &PlatformAvailability{
					Platform: acc.Platform,
				}
			}
			p := platform[acc.Platform]
			p.TotalAccounts++
			if isAvailable {
				p.AvailableCount++
			}
			if isRateLimited {
				p.RateLimitCount++
			}
			if hasError {
				p.ErrorCount++
			}
		}
		// Accumulate per-group counters; an account may belong to several groups.
		for _, grp := range acc.Groups {
			if grp == nil || grp.ID <= 0 {
				continue
			}
			if _, ok := group[grp.ID]; !ok {
				group[grp.ID] = &GroupAvailability{
					GroupID:   grp.ID,
					GroupName: grp.Name,
					Platform:  grp.Platform,
				}
			}
			g := group[grp.ID]
			g.TotalAccounts++
			if isAvailable {
				g.AvailableCount++
			}
			if isRateLimited {
				g.RateLimitCount++
			}
			if hasError {
				g.ErrorCount++
			}
		}
		// For display purposes only the first group is surfaced on the account row.
		displayGroupID := int64(0)
		displayGroupName := ""
		if len(acc.Groups) > 0 && acc.Groups[0] != nil {
			displayGroupID = acc.Groups[0].ID
			displayGroupName = acc.Groups[0].Name
		}
		item := &AccountAvailability{
			AccountID:     acc.ID,
			AccountName:   acc.Name,
			Platform:      acc.Platform,
			GroupID:       displayGroupID,
			GroupName:     displayGroupName,
			Status:        acc.Status,
			IsAvailable:   isAvailable,
			IsRateLimited: isRateLimited,
			IsOverloaded:  isOverloaded,
			HasError:      hasError,
			ErrorMessage:  acc.ErrorMessage,
		}
		// Remaining seconds are only reported while they are still positive.
		if isRateLimited && acc.RateLimitResetAt != nil {
			item.RateLimitResetAt = acc.RateLimitResetAt
			remainingSec := int64(time.Until(*acc.RateLimitResetAt).Seconds())
			if remainingSec > 0 {
				item.RateLimitRemainingSec = &remainingSec
			}
		}
		if isOverloaded && acc.OverloadUntil != nil {
			item.OverloadUntil = acc.OverloadUntil
			remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
			if remainingSec > 0 {
				item.OverloadRemainingSec = &remainingSec
			}
		}
		if isTempUnsched && acc.TempUnschedulableUntil != nil {
			item.TempUnschedulableUntil = acc.TempUnschedulableUntil
		}
		account[acc.ID] = item
	}
	return platform, group, account, &collectedAt, nil
}
// OpsAccountAvailability bundles the availability view returned to callers for
// a single (optionally group-scoped) snapshot.
type OpsAccountAvailability struct {
	Group       *GroupAvailability             // aggregate for the filtered group; nil when no group filter was applied
	Accounts    map[int64]*AccountAvailability // per-account detail keyed by account ID; never nil
	CollectedAt *time.Time                     // when the snapshot was computed
}
// GetAccountAvailability returns an availability snapshot, optionally scoped
// to one group. When an override hook is installed (presumably for tests) it
// delegates to it; otherwise the result is derived from
// GetAccountAvailabilityStats.
func (s *OpsService) GetAccountAvailability(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error) {
	if s == nil {
		return nil, errors.New("ops service is nil")
	}
	if s.getAccountAvailability != nil {
		return s.getAccountAvailability(ctx, platformFilter, groupIDFilter)
	}

	_, groupStats, accountStats, collectedAt, err := s.GetAccountAvailabilityStats(ctx, platformFilter, groupIDFilter)
	if err != nil {
		return nil, err
	}

	out := &OpsAccountAvailability{
		Accounts:    accountStats,
		CollectedAt: collectedAt,
	}
	// Guarantee a non-nil Accounts map for callers.
	if out.Accounts == nil {
		out.Accounts = map[int64]*AccountAvailability{}
	}
	// Only resolve a group aggregate when a positive group filter was supplied.
	if groupIDFilter != nil && *groupIDFilter > 0 {
		out.Group = groupStats[*groupIDFilter]
	}
	return out, nil
}
package service
import (
"context"
"database/sql"
"hash/fnv"
"time"
)
func hashAdvisoryLockID(key string) int64 {
h := fnv.New64a()
_, _ = h.Write([]byte(key))
return int64(h.Sum64())
}
// tryAcquireDBAdvisoryLock attempts to take a PostgreSQL session-level
// advisory lock on a dedicated pooled connection. On success it returns a
// release func (which unlocks and closes the connection) and true. On any
// failure it returns (nil, false). The lock is bound to the pinned *sql.Conn,
// so that connection must stay open until release is called.
func tryAcquireDBAdvisoryLock(ctx context.Context, db *sql.DB, lockID int64) (func(), bool) {
	if db == nil {
		return nil, false
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Pin a single connection: advisory locks are per-session in PostgreSQL.
	conn, err := db.Conn(ctx)
	if err != nil {
		return nil, false
	}
	acquired := false
	if err := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired); err != nil {
		_ = conn.Close()
		return nil, false
	}
	if !acquired {
		// Lock held elsewhere; return the connection to the pool.
		_ = conn.Close()
		return nil, false
	}
	release := func() {
		// Use a fresh short-lived context so release works even after the
		// caller's ctx has been canceled.
		unlockCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		_, _ = conn.ExecContext(unlockCtx, "SELECT pg_advisory_unlock($1)", lockID)
		_ = conn.Close()
	}
	return release, true
}
package service
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// Tunables for the ops pre-aggregation jobs (hourly/daily backfill cadence,
// windows, chunking, timeouts, and leader-lock settings).
const (
	opsAggHourlyJobName = "ops_preaggregation_hourly"
	opsAggDailyJobName  = "ops_preaggregation_daily"

	opsAggHourlyInterval = 10 * time.Minute
	opsAggDailyInterval  = 1 * time.Hour

	// Keep in sync with ops retention target (vNext default 30d).
	opsAggBackfillWindow = 30 * 24 * time.Hour

	// Recompute overlap to absorb late-arriving rows near boundaries.
	opsAggHourlyOverlap = 2 * time.Hour
	opsAggDailyOverlap  = 48 * time.Hour

	// Chunk sizes bound the span of a single upsert statement.
	opsAggHourlyChunk = 24 * time.Hour
	opsAggDailyChunk  = 7 * 24 * time.Hour

	// Delay around boundaries (e.g. 10:00..10:05) to avoid aggregating buckets
	// that may still receive late inserts.
	opsAggSafeDelay = 5 * time.Minute

	opsAggMaxQueryTimeout = 3 * time.Second
	opsAggHourlyTimeout   = 5 * time.Minute
	opsAggDailyTimeout    = 2 * time.Minute

	opsAggHourlyLeaderLockKey = "ops:aggregation:hourly:leader"
	opsAggDailyLeaderLockKey  = "ops:aggregation:daily:leader"
	opsAggHourlyLeaderLockTTL = 15 * time.Minute
	opsAggDailyLeaderLockTTL  = 10 * time.Minute
)
// OpsAggregationService periodically backfills ops_metrics_hourly / ops_metrics_daily
// for stable long-window dashboard queries.
//
// It is safe to run in multi-replica deployments when Redis is available (leader lock).
type OpsAggregationService struct {
	opsRepo     OpsRepository     // metrics read/write + job heartbeats
	settingRepo SettingRepository // runtime monitoring on/off switch
	cfg         *config.Config
	db          *sql.DB       // fallback advisory-lock source when Redis is unavailable
	redisClient *redis.Client // preferred leader-lock backend; may be nil
	instanceID  string        // unique per process; used as the leader-lock owner value

	stopCh    chan struct{}
	startOnce sync.Once
	stopOnce  sync.Once

	// Serialize each aggregation kind within this process.
	hourlyMu sync.Mutex
	dailyMu  sync.Mutex

	// Throttle "lock held elsewhere" log spam.
	skipLogMu sync.Mutex
	skipLogAt time.Time
}
// NewOpsAggregationService wires an aggregation service with its dependencies.
// Each instance receives a unique ID used as the leader-lock owner value.
func NewOpsAggregationService(
	opsRepo OpsRepository,
	settingRepo SettingRepository,
	db *sql.DB,
	redisClient *redis.Client,
	cfg *config.Config,
) *OpsAggregationService {
	svc := &OpsAggregationService{
		opsRepo:     opsRepo,
		settingRepo: settingRepo,
		db:          db,
		redisClient: redisClient,
		cfg:         cfg,
		instanceID:  uuid.NewString(),
	}
	return svc
}
// Start launches the hourly and daily aggregation loops. Subsequent calls are
// no-ops; calling on a nil receiver is safe.
func (s *OpsAggregationService) Start() {
	if s == nil {
		return
	}
	s.startOnce.Do(func() {
		// Lazily create the stop channel so a zero-value service still works.
		if s.stopCh == nil {
			s.stopCh = make(chan struct{})
		}
		go s.hourlyLoop()
		go s.dailyLoop()
	})
}
// Stop signals both aggregation loops to exit. It is idempotent and safe on a
// nil receiver.
func (s *OpsAggregationService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		if ch := s.stopCh; ch != nil {
			close(ch)
		}
	})
}
// hourlyLoop runs hourly aggregation once immediately, then on a fixed ticker
// until the service is stopped.
func (s *OpsAggregationService) hourlyLoop() {
	// First run right away so a fresh deployment backfills quickly.
	s.aggregateHourly()

	tick := time.NewTicker(opsAggHourlyInterval)
	defer tick.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-tick.C:
			s.aggregateHourly()
		}
	}
}
// dailyLoop runs daily aggregation once immediately, then on a fixed ticker
// until the service is stopped.
func (s *OpsAggregationService) dailyLoop() {
	// First run right away so a fresh deployment backfills quickly.
	s.aggregateDaily()

	tick := time.NewTicker(opsAggDailyInterval)
	defer tick.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-tick.C:
			s.aggregateDaily()
		}
	}
}
// aggregateHourly backfills ops_metrics_hourly for fully elapsed hours inside
// the retention window. It is a no-op when ops/aggregation is disabled,
// monitoring is off, or another instance currently holds the leader lock.
// The run outcome (success or error) is recorded as a job heartbeat.
func (s *OpsAggregationService) aggregateHourly() {
	if s == nil || s.opsRepo == nil {
		return
	}
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAggHourlyTimeout)
	defer cancel()
	if !s.isMonitoringEnabled(ctx) {
		return
	}
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggHourlyLeaderLockKey, opsAggHourlyLeaderLockTTL, "[OpsAggregation][hourly]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	// Also serialize within this process.
	s.hourlyMu.Lock()
	defer s.hourlyMu.Unlock()
	startedAt := time.Now().UTC()
	runAt := startedAt
	// Aggregate stable full hours only.
	end := utcFloorToHour(time.Now().UTC().Add(-opsAggSafeDelay))
	start := end.Add(-opsAggBackfillWindow)
	// Resume from the latest bucket with overlap.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax)
		cancelMax()
		if err != nil {
			log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err)
		} else if ok {
			// Re-aggregate a small overlap to absorb late-arriving rows.
			candidate := latest.Add(-opsAggHourlyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}
	start = utcFloorToHour(start)
	if !start.Before(end) {
		return
	}
	var aggErr error
	// Upsert in bounded chunks so a single statement never spans the full window.
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggHourlyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end)
		if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
			break
		}
	}
	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs
	if aggErr != nil {
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		// Heartbeats use a short independent context so they still land even
		// when the main ctx has expired.
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggHourlyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}
	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggHourlyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
		LastResult:     &result,
	})
}
// aggregateDaily backfills ops_metrics_daily for fully elapsed UTC days inside
// the retention window, mirroring aggregateHourly's gating, leader-lock,
// chunking, and heartbeat behavior.
func (s *OpsAggregationService) aggregateDaily() {
	if s == nil || s.opsRepo == nil {
		return
	}
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAggDailyTimeout)
	defer cancel()
	if !s.isMonitoringEnabled(ctx) {
		return
	}
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggDailyLeaderLockKey, opsAggDailyLeaderLockTTL, "[OpsAggregation][daily]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	// Also serialize within this process.
	s.dailyMu.Lock()
	defer s.dailyMu.Unlock()
	startedAt := time.Now().UTC()
	runAt := startedAt
	// Only fully elapsed UTC days are aggregated.
	end := utcFloorToDay(time.Now().UTC())
	start := end.Add(-opsAggBackfillWindow)
	// Resume from the latest bucket with overlap.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax)
		cancelMax()
		if err != nil {
			log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err)
		} else if ok {
			candidate := latest.Add(-opsAggDailyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}
	start = utcFloorToDay(start)
	if !start.Before(end) {
		return
	}
	var aggErr error
	// Upsert in bounded chunks so a single statement never spans the full window.
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggDailyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end)
		if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
			break
		}
	}
	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs
	if aggErr != nil {
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		// Heartbeats use a short independent context so they still land even
		// when the main ctx has expired.
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggDailyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}
	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggDailyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
		LastResult:     &result,
	})
}
// isMonitoringEnabled reports whether ops monitoring is currently switched on.
// The static config acts as a master switch; the runtime setting can then turn
// monitoring off. Missing or unreadable settings fail open (enabled), but
// unexpected read errors are logged instead of being silently swallowed —
// previously both error branches returned true, leaving the
// errors.Is(err, ErrSettingNotFound) check as dead code.
func (s *OpsAggregationService) isMonitoringEnabled(ctx context.Context) bool {
	if s == nil {
		return false
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return false
	}
	if s.settingRepo == nil {
		return true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
	if err != nil {
		// Fail open: a missing setting is normal, but surface other errors.
		if !errors.Is(err, ErrSettingNotFound) {
			log.Printf("[OpsAggregation] failed to read monitoring setting: %v", err)
		}
		return true
	}
	// Only explicit "off" values disable monitoring.
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "false", "0", "off", "disabled":
		return false
	default:
		return true
	}
}
// opsAggReleaseScript atomically deletes the leader-lock key only when it is
// still owned by the caller (compare-and-delete), so a replica whose lock
// expired cannot release a lock re-acquired by another replica.
var opsAggReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
// tryAcquireLeaderLock attempts to become the aggregation leader for the given
// key. It prefers a Redis SETNX lock (multi-instance safe); on Redis errors it
// falls back to a PostgreSQL advisory lock. On success it returns a release
// func (which may be nil) and true; when the lock is held elsewhere it logs a
// throttled skip message and returns false.
func (s *OpsAggregationService) tryAcquireLeaderLock(ctx context.Context, key string, ttl time.Duration, logPrefix string) (func(), bool) {
	if s == nil {
		return nil, false
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Prefer Redis leader lock when available (multi-instance), but avoid stampeding
	// the DB when Redis is flaky by falling back to a DB advisory lock.
	if s.redisClient != nil {
		ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
		if err == nil {
			if !ok {
				// Another instance owns the lock.
				s.maybeLogSkip(logPrefix)
				return nil, false
			}
			release := func() {
				// Fresh short context so release works after ctx cancellation.
				ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
				defer cancel()
				// Compare-and-delete: only remove the key if we still own it.
				_, _ = opsAggReleaseScript.Run(ctx2, s.redisClient, []string{key}, s.instanceID).Result()
			}
			return release, true
		}
		// Redis error: fall through to DB advisory lock.
	}
	release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
	if !ok {
		s.maybeLogSkip(logPrefix)
		return nil, false
	}
	return release, true
}
// maybeLogSkip logs a "leader lock held elsewhere" message, throttled to at
// most once per minute across both aggregation loops.
func (s *OpsAggregationService) maybeLogSkip(prefix string) {
	s.skipLogMu.Lock()
	defer s.skipLogMu.Unlock()

	now := time.Now()
	if !s.skipLogAt.IsZero() && now.Sub(s.skipLogAt) < time.Minute {
		return // throttled
	}
	s.skipLogAt = now

	label := prefix
	if label == "" {
		label = "[OpsAggregation]"
	}
	log.Printf("%s leader lock held by another instance; skipping", label)
}
func utcFloorToHour(t time.Time) time.Time {
return t.UTC().Truncate(time.Hour)
}
func utcFloorToDay(t time.Time) time.Time {
u := t.UTC()
y, m, d := u.Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
func minTime(a, b time.Time) time.Time {
if a.Before(b) {
return a
}
return b
}
package service
import (
"context"
"fmt"
"log"
"math"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// Tunables for the ops alert evaluator loop.
const (
	opsAlertEvaluatorJobName = "ops_alert_evaluator"
	// Hard cap for a single evaluation pass.
	opsAlertEvaluatorTimeout = 45 * time.Second
	// Redis leader-lock key/TTL guarding multi-replica evaluation.
	opsAlertEvaluatorLeaderLockKey = "ops:alert:evaluator:leader"
	opsAlertEvaluatorLeaderLockTTL = 90 * time.Second
	// Throttle for "lock held elsewhere" log lines.
	opsAlertEvaluatorSkipLogInterval = 1 * time.Minute
)
// opsAlertEvaluatorReleaseScript atomically deletes the evaluator leader-lock
// key only when it is still owned by the caller (compare-and-delete).
var opsAlertEvaluatorReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
// OpsAlertEvaluatorService periodically evaluates ops alert rules and emits
// events/notifications. Multi-replica deployments coordinate through a Redis
// leader lock.
type OpsAlertEvaluatorService struct {
	opsService   *OpsService   // runtime settings + monitoring toggle
	opsRepo      OpsRepository // rules, events, metrics, heartbeats
	emailService *EmailService // outbound alert notifications
	redisClient  *redis.Client // leader-lock backend; may be nil (single instance)
	cfg          *config.Config
	instanceID   string // unique per process; leader-lock owner value

	stopCh    chan struct{}
	startOnce sync.Once
	stopOnce  sync.Once
	wg        sync.WaitGroup // tracks the background run() goroutine

	// mu guards ruleStates (per-rule evaluation bookkeeping).
	mu         sync.Mutex
	ruleStates map[int64]*opsAlertRuleState

	// emailLimiter rate-limits outbound alert emails.
	emailLimiter *slidingWindowLimiter

	// Throttle for "lock held elsewhere" log lines.
	skipLogMu sync.Mutex
	skipLogAt time.Time

	warnNoRedisOnce sync.Once
}
// opsAlertRuleState tracks per-rule bookkeeping carried between evaluator passes.
type opsAlertRuleState struct {
	LastEvaluatedAt     time.Time // when the rule was last evaluated
	ConsecutiveBreaches int       // back-to-back threshold breaches observed so far
}
// NewOpsAlertEvaluatorService builds an alert evaluator with its dependencies.
// Each instance receives a unique ID used as the leader-lock owner value.
func NewOpsAlertEvaluatorService(
	opsService *OpsService,
	opsRepo OpsRepository,
	emailService *EmailService,
	redisClient *redis.Client,
	cfg *config.Config,
) *OpsAlertEvaluatorService {
	svc := &OpsAlertEvaluatorService{
		opsService:   opsService,
		opsRepo:      opsRepo,
		emailService: emailService,
		redisClient:  redisClient,
		cfg:          cfg,
		instanceID:   uuid.NewString(),
		ruleStates:   make(map[int64]*opsAlertRuleState),
		emailLimiter: newSlidingWindowLimiter(0, time.Hour),
	}
	return svc
}
// Start launches the evaluator loop. Subsequent calls are no-ops; calling on a
// nil receiver is safe.
func (s *OpsAlertEvaluatorService) Start() {
	if s == nil {
		return
	}
	s.startOnce.Do(func() {
		// Lazily create the stop channel so a zero-value service still works.
		if s.stopCh == nil {
			s.stopCh = make(chan struct{})
		}
		go s.run()
	})
}
// Stop signals the evaluator loop to exit and waits for it to finish. It is
// idempotent and safe on a nil receiver.
func (s *OpsAlertEvaluatorService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		if ch := s.stopCh; ch != nil {
			close(ch)
		}
	})
	// Block until the background goroutine has drained.
	s.wg.Wait()
}
// run is the evaluator's main loop: it fires once immediately, then re-arms a
// single timer after each pass so interval changes take effect without a
// restart.
//
// NOTE(review): wg.Add happens inside the goroutine, so a Stop racing a fresh
// Start could call wg.Wait before Add runs — consider moving Add into Start.
func (s *OpsAlertEvaluatorService) run() {
	s.wg.Add(1)
	defer s.wg.Done()
	// Start immediately to produce early feedback in ops dashboard.
	timer := time.NewTimer(0)
	defer timer.Stop()
	for {
		select {
		case <-timer.C:
			// Re-read the interval each cycle; settings may have changed.
			interval := s.getInterval()
			s.evaluateOnce(interval)
			timer.Reset(interval)
		case <-s.stopCh:
			return
		}
	}
}
// getInterval returns the evaluation interval from runtime settings, falling
// back to 60s whenever the settings are unavailable or the configured value is
// outside the accepted [1s, 24h] range. The original code tested the lower
// bound twice (`<= 0` then `< 1`); the redundant check is collapsed into a
// single range test.
func (s *OpsAlertEvaluatorService) getInterval() time.Duration {
	const defaultInterval = 60 * time.Second
	if s == nil || s.opsService == nil {
		return defaultInterval
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	cfg, err := s.opsService.GetOpsAlertRuntimeSettings(ctx)
	if err != nil || cfg == nil {
		return defaultInterval
	}
	// Reject non-positive and absurdly large (> 24h) values in one check.
	secs := cfg.EvaluationIntervalSeconds
	if secs < 1 || secs > int((24*time.Hour).Seconds()) {
		return defaultInterval
	}
	return time.Duration(secs) * time.Second
}
// evaluateOnce performs a single evaluation pass over all alert rules.
//
// Flow: guard checks (ops + monitoring enabled) → optional leader lock →
// load rules → per enabled rule: compute the metric over the rule's window,
// update the consecutive-breach streak, then either fire a new event
// (honoring scoped silences and the per-rule cooldown) or resolve the
// currently active one. A heartbeat row summarizing the pass is recorded at
// the end. interval is the current loop cadence, used to convert a rule's
// "sustained minutes" into a required number of consecutive breaches.
func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
	if s == nil || s.opsRepo == nil {
		return
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAlertEvaluatorTimeout)
	defer cancel()
	if s.opsService != nil && !s.opsService.IsMonitoringEnabled(ctx) {
		return
	}
	// Runtime settings are best-effort: fall back to defaults on error.
	runtimeCfg := defaultOpsAlertRuntimeSettings()
	if s.opsService != nil {
		if loaded, err := s.opsService.GetOpsAlertRuntimeSettings(ctx); err == nil && loaded != nil {
			runtimeCfg = loaded
		}
	}
	// In multi-node deployments only the leader evaluates this cycle.
	release, ok := s.tryAcquireLeaderLock(ctx, runtimeCfg.DistributedLock)
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	startedAt := time.Now().UTC()
	runAt := startedAt
	rules, err := s.opsRepo.ListAlertRules(ctx)
	if err != nil {
		s.recordHeartbeatError(runAt, time.Since(startedAt), err)
		log.Printf("[OpsAlertEvaluator] list rules failed: %v", err)
		return
	}
	// Counters for the heartbeat summary line.
	rulesTotal := len(rules)
	rulesEnabled := 0
	rulesEvaluated := 0
	eventsCreated := 0
	eventsResolved := 0
	emailsSent := 0
	now := time.Now().UTC()
	// Evaluate only up to the last full minute so partially aggregated data
	// does not skew window metrics.
	safeEnd := now.Truncate(time.Minute)
	if safeEnd.IsZero() {
		safeEnd = now
	}
	// Best-effort: system metrics may be absent; rules that need them simply
	// fail to evaluate this cycle (error intentionally ignored).
	systemMetrics, _ := s.opsRepo.GetLatestSystemMetrics(ctx, 1)
	// Cleanup stale state for removed rules.
	s.pruneRuleStates(rules)
	for _, rule := range rules {
		if rule == nil || !rule.Enabled || rule.ID <= 0 {
			continue
		}
		rulesEnabled++
		scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
		windowMinutes := rule.WindowMinutes
		if windowMinutes <= 0 {
			windowMinutes = 1
		}
		windowStart := safeEnd.Add(-time.Duration(windowMinutes) * time.Minute)
		windowEnd := safeEnd
		metricValue, ok := s.computeRuleMetric(ctx, rule, systemMetrics, windowStart, windowEnd, scopePlatform, scopeGroupID)
		if !ok {
			// Metric unavailable this cycle: drop the breach streak so a
			// data gap cannot contribute to a sustained alert.
			s.resetRuleState(rule.ID, now)
			continue
		}
		rulesEvaluated++
		breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
		required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
		consecutive := s.updateRuleBreaches(rule.ID, now, interval, breachedNow)
		activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID)
		if err != nil {
			log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
			continue
		}
		if breachedNow && consecutive >= required {
			// Already firing: nothing to do.
			if activeEvent != nil {
				continue
			}
			// Scoped silencing: if a matching silence exists, skip creating a firing event.
			if s.opsService != nil {
				platform := strings.TrimSpace(scopePlatform)
				region := scopeRegion
				if platform != "" {
					if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
						continue
					}
				}
			}
			latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
				continue
			}
			// Cooldown: suppress re-firing too soon after the previous event.
			if latestEvent != nil && rule.CooldownMinutes > 0 {
				cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
				if now.Sub(latestEvent.FiredAt) < cooldown {
					continue
				}
			}
			firedEvent := &OpsAlertEvent{
				RuleID:         rule.ID,
				Severity:       strings.TrimSpace(rule.Severity),
				Status:         OpsAlertStatusFiring,
				Title:          fmt.Sprintf("%s: %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name)),
				Description:    buildOpsAlertDescription(rule, metricValue, windowMinutes, scopePlatform, scopeGroupID),
				MetricValue:    float64Ptr(metricValue),
				ThresholdValue: float64Ptr(rule.Threshold),
				Dimensions:     buildOpsAlertDimensions(scopePlatform, scopeGroupID),
				FiredAt:        now,
				CreatedAt:      now,
			}
			created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
				continue
			}
			eventsCreated++
			if created != nil && created.ID > 0 {
				if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
					emailsSent++
				}
			}
			continue
		}
		// Not breached: resolve active event if present.
		if activeEvent != nil {
			resolvedAt := now
			if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
				log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
			} else {
				eventsResolved++
			}
		}
	}
	result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
	s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
}
// pruneRuleStates drops in-memory breach state for rules that no longer
// exist, so deleted rules do not leak tracker entries between cycles.
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
	s.mu.Lock()
	defer s.mu.Unlock()
	alive := make(map[int64]struct{}, len(rules))
	for _, rule := range rules {
		if rule == nil || rule.ID <= 0 {
			continue
		}
		alive[rule.ID] = struct{}{}
	}
	for id := range s.ruleStates {
		if _, keep := alive[id]; keep {
			continue
		}
		delete(s.ruleStates, id)
	}
}
// resetRuleState clears the consecutive-breach counter for a rule (used when
// its metric could not be computed this cycle) while still recording the
// evaluation timestamp.
func (s *OpsAlertEvaluatorService) resetRuleState(ruleID int64, now time.Time) {
	if ruleID <= 0 {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	st := s.ruleStates[ruleID]
	if st == nil {
		st = &opsAlertRuleState{}
		s.ruleStates[ruleID] = st
	}
	st.LastEvaluatedAt = now
	st.ConsecutiveBreaches = 0
}
// updateRuleBreaches records the outcome of one evaluation for a rule and
// returns the resulting consecutive-breach count. If more than two intervals
// elapsed since the previous evaluation (evaluator was down, or lost
// leadership), the running streak is considered stale and restarts from zero
// before this sample is applied.
func (s *OpsAlertEvaluatorService) updateRuleBreaches(ruleID int64, now time.Time, interval time.Duration, breached bool) int {
	if ruleID <= 0 {
		return 0
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	state, ok := s.ruleStates[ruleID]
	if !ok {
		state = &opsAlertRuleState{}
		s.ruleStates[ruleID] = state
	}
	// Stale-gap detection: a long pause invalidates the running streak.
	if !state.LastEvaluatedAt.IsZero() && interval > 0 {
		if now.Sub(state.LastEvaluatedAt) > interval*2 {
			state.ConsecutiveBreaches = 0
		}
	}
	state.LastEvaluatedAt = now
	if breached {
		state.ConsecutiveBreaches++
	} else {
		state.ConsecutiveBreaches = 0
	}
	return state.ConsecutiveBreaches
}
func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int {
if sustainedMinutes <= 0 {
return 1
}
if interval <= 0 {
return sustainedMinutes
}
required := int(math.Ceil(float64(sustainedMinutes*60) / interval.Seconds()))
if required < 1 {
return 1
}
return required
}
// parseOpsAlertRuleScope extracts the optional scope filters (platform,
// group_id, region) from a rule's free-form filter map. Missing or invalid
// entries yield zero values. group_id accepts JSON numbers (float64), Go
// ints, and numeric strings; non-positive IDs are ignored. Platform and
// region values are whitespace-trimmed, and an empty region stays nil.
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
	if filters == nil {
		return "", nil, nil
	}
	if raw, ok := filters["platform"].(string); ok {
		platform = strings.TrimSpace(raw)
	}
	setID := func(id int64) {
		if id > 0 {
			groupID = &id
		}
	}
	switch v := filters["group_id"].(type) {
	case float64:
		setID(int64(v))
	case int64:
		setID(v)
	case int:
		setID(int64(v))
	case string:
		if n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
			setID(n)
		}
	}
	if raw, ok := filters["region"].(string); ok {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			region = &trimmed
		}
	}
	return platform, groupID, region
}
// computeRuleMetric resolves the current value for a rule's metric type.
// It returns (value, true) on success, or (0, false) when the metric cannot
// be computed this cycle (missing system snapshot fields, missing group
// scope, lookup errors, insufficient SLA sample, or an unknown metric type)
// — in which case the caller resets the rule's breach streak.
//
// Snapshot metrics (cpu/memory/queue depth) come from the latest system
// metrics row; account metrics come from the availability service; the
// remaining rate metrics are derived from the dashboard overview for the
// rule's [start, end) window and platform/group scope.
//
// Improvement over the original: the metric type is trimmed once up front
// instead of in two separate switch statements.
func (s *OpsAlertEvaluatorService) computeRuleMetric(
	ctx context.Context,
	rule *OpsAlertRule,
	systemMetrics *OpsSystemMetricsSnapshot,
	start time.Time,
	end time.Time,
	platform string,
	groupID *int64,
) (float64, bool) {
	if rule == nil {
		return 0, false
	}
	metricType := strings.TrimSpace(rule.MetricType)
	switch metricType {
	case "cpu_usage_percent":
		if systemMetrics != nil && systemMetrics.CPUUsagePercent != nil {
			return *systemMetrics.CPUUsagePercent, true
		}
		return 0, false
	case "memory_usage_percent":
		if systemMetrics != nil && systemMetrics.MemoryUsagePercent != nil {
			return *systemMetrics.MemoryUsagePercent, true
		}
		return 0, false
	case "concurrency_queue_depth":
		if systemMetrics != nil && systemMetrics.ConcurrencyQueueDepth != nil {
			return float64(*systemMetrics.ConcurrencyQueueDepth), true
		}
		return 0, false
	case "group_available_accounts":
		// Group-scoped metric: a valid group_id filter is required.
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		// No group block means zero available accounts (still evaluable).
		if availability.Group == nil {
			return 0, true
		}
		return float64(availability.Group.AvailableCount), true
	case "group_available_ratio":
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		return computeGroupAvailableRatio(availability.Group), true
	case "account_rate_limited_count":
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})), true
	case "account_error_count":
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		// Accounts parked behind a temporary unschedulable window are not
		// counted as errored.
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.HasError && acc.TempUnschedulableUntil == nil
		})), true
	}
	// Remaining metric types are derived from the dashboard overview for the
	// rule's window and scope (raw query mode).
	overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
		StartTime: start,
		EndTime:   end,
		Platform:  platform,
		GroupID:   groupID,
		QueryMode: OpsQueryModeRaw,
	})
	if err != nil || overview == nil {
		return 0, false
	}
	switch metricType {
	case "success_rate":
		// Require at least one SLA-eligible request to avoid 0/0 noise.
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.SLA * 100, true
	case "error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.ErrorRate * 100, true
	case "upstream_error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.UpstreamErrorRate * 100, true
	default:
		return 0, false
	}
}
// compareMetric evaluates "value <operator> threshold" for the supported
// rule operators (>, >=, <, <=, ==, !=). The operator is whitespace-trimmed
// first; unknown or empty operators never match.
func compareMetric(value float64, operator string, threshold float64) bool {
	op := strings.TrimSpace(operator)
	switch op {
	case "==":
		return value == threshold
	case "!=":
		return value != threshold
	case "<":
		return value < threshold
	case "<=":
		return value <= threshold
	case ">":
		return value > threshold
	case ">=":
		return value >= threshold
	}
	return false
}
// buildOpsAlertDimensions assembles the optional event dimensions map from
// the rule scope. It returns nil when neither platform nor group is set so
// the field is omitted from persisted events.
func buildOpsAlertDimensions(platform string, groupID *int64) map[string]any {
	var dims map[string]any
	put := func(key string, value any) {
		if dims == nil {
			dims = map[string]any{}
		}
		dims[key] = value
	}
	if p := strings.TrimSpace(platform); p != "" {
		put("platform", p)
	}
	if groupID != nil && *groupID > 0 {
		put("group_id", *groupID)
	}
	return dims
}
// buildOpsAlertDescription renders the human-readable event description in
// the form "<metric> <op> <threshold> (current <value>) over last <N>m
// (<scope>)". Scope is "overall" unless a platform filter is set, with the
// group id appended when present; a non-positive window is shown as 1m.
func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes int, platform string, groupID *int64) string {
	if rule == nil {
		return ""
	}
	scope := "overall"
	if p := strings.TrimSpace(platform); p != "" {
		scope = "platform=" + p
	}
	if groupID != nil && *groupID > 0 {
		scope += " group_id=" + strconv.FormatInt(*groupID, 10)
	}
	minutes := windowMinutes
	if minutes <= 0 {
		minutes = 1
	}
	return fmt.Sprintf("%s %s %.2f (current %.2f) over last %dm (%s)",
		strings.TrimSpace(rule.MetricType),
		strings.TrimSpace(rule.Operator),
		rule.Threshold,
		value,
		minutes,
		scope,
	)
}
// maybeSendAlertEmail sends the notification email for a newly fired event,
// best-effort, and reports whether at least one recipient received it.
// Gates, in order: not already sent, rule opted into email, global email
// config enabled with recipients, minimum-severity filter, runtime
// silencing, then the shared hourly rate limiter applied per recipient.
// Per-recipient send failures are intentionally ignored.
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
	if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
		return false
	}
	if event.EmailSent {
		return false
	}
	if !rule.NotifyEmail {
		return false
	}
	emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
	if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
		return false
	}
	if len(emailCfg.Alert.Recipients) == 0 {
		return false
	}
	if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
		return false
	}
	if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
		if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
			return false
		}
	}
	// Apply/update rate limiter.
	s.emailLimiter.SetLimit(emailCfg.Alert.RateLimitPerHour)
	subject := fmt.Sprintf("[Ops Alert][%s] %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name))
	body := buildOpsAlertEmailBody(rule, event)
	anySent := false
	for _, to := range emailCfg.Alert.Recipients {
		addr := strings.TrimSpace(to)
		if addr == "" {
			continue
		}
		// The limiter counts each attempted recipient; over-limit recipients
		// are skipped silently.
		if !s.emailLimiter.Allow(time.Now().UTC()) {
			continue
		}
		if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
			// Ignore per-recipient failures; continue best-effort.
			continue
		}
		anySent = true
	}
	if anySent {
		// Fresh context on purpose: ctx may be near its deadline, and the
		// sent flag should be persisted regardless (error ignored —
		// best-effort bookkeeping).
		_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
	}
	return anySent
}
// buildOpsAlertEmailBody renders the HTML email body for an alert event.
// All dynamic fields are HTML-escaped via htmlEscape; the metric value falls
// back to "-" and the threshold to the rule's configured threshold when the
// event carries no recorded values.
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
	if rule == nil || event == nil {
		return ""
	}
	metric := strings.TrimSpace(rule.MetricType)
	value := "-"
	threshold := fmt.Sprintf("%.2f", rule.Threshold)
	if event.MetricValue != nil {
		value = fmt.Sprintf("%.2f", *event.MetricValue)
	}
	if event.ThresholdValue != nil {
		threshold = fmt.Sprintf("%.2f", *event.ThresholdValue)
	}
	return fmt.Sprintf(`
<h2>Ops Alert</h2>
<p><b>Rule</b>: %s</p>
<p><b>Severity</b>: %s</p>
<p><b>Status</b>: %s</p>
<p><b>Metric</b>: %s %s %s</p>
<p><b>Fired at</b>: %s</p>
<p><b>Description</b>: %s</p>
`,
		htmlEscape(rule.Name),
		htmlEscape(rule.Severity),
		htmlEscape(event.Status),
		htmlEscape(metric),
		htmlEscape(rule.Operator),
		htmlEscape(fmt.Sprintf("%s (threshold %s)", value, threshold)),
		event.FiredAt.Format(time.RFC3339),
		htmlEscape(event.Description),
	)
}
// shouldSendOpsAlertEmailByMinSeverity reports whether a rule's severity
// clears the configured minimum email severity. An empty minimum means
// "send everything". Levels rank critical(3) > warning(2) > info(1); an
// unrecognized minimum ranks 0 and therefore also lets everything through,
// while an unrecognized rule severity is mapped to "info" by
// opsEmailSeverityForOps.
//
// The original body re-lowered the already lowercased minSeverity into a
// separate minLevel variable; that redundancy is removed.
func shouldSendOpsAlertEmailByMinSeverity(minSeverity string, ruleSeverity string) bool {
	minSeverity = strings.ToLower(strings.TrimSpace(minSeverity))
	if minSeverity == "" {
		return true
	}
	rank := func(level string) int {
		switch level {
		case "critical":
			return 3
		case "warning":
			return 2
		case "info":
			return 1
		default:
			return 0
		}
	}
	// Rule severity (P0/P1/...) is mapped onto the email scale before
	// ranking; minSeverity is already normalized above.
	return rank(opsEmailSeverityForOps(ruleSeverity)) >= rank(minSeverity)
}
// opsEmailSeverityForOps maps an ops severity label (P0, P1, ...) onto the
// email notification severity scale; any unrecognized label is "info".
func opsEmailSeverityForOps(severity string) string {
	normalized := strings.ToUpper(strings.TrimSpace(severity))
	if normalized == "P0" {
		return "critical"
	}
	if normalized == "P1" {
		return "warning"
	}
	return "info"
}
// isOpsAlertSilenced reports whether email notification for the event is
// suppressed by runtime silencing settings: either a global "silence until"
// timestamp that is still in the future, or any silence entry whose until
// time has not passed and whose optional rule/severity constraints match.
// Entries with missing or unparseable timestamps are skipped.
// NOTE(review): event (and rule) are dereferenced when an entry lists
// severities; the current caller guarantees both are non-nil — confirm
// before reusing this helper elsewhere.
func isOpsAlertSilenced(now time.Time, rule *OpsAlertRule, event *OpsAlertEvent, silencing OpsAlertSilencingSettings) bool {
	if !silencing.Enabled {
		return false
	}
	if now.IsZero() {
		now = time.Now().UTC()
	}
	// Global silence window takes precedence over per-entry matching.
	if strings.TrimSpace(silencing.GlobalUntilRFC3339) != "" {
		if t, err := time.Parse(time.RFC3339, strings.TrimSpace(silencing.GlobalUntilRFC3339)); err == nil {
			if now.Before(t) {
				return true
			}
		}
	}
	for _, entry := range silencing.Entries {
		untilRaw := strings.TrimSpace(entry.UntilRFC3339)
		if untilRaw == "" {
			continue
		}
		until, err := time.Parse(time.RFC3339, untilRaw)
		if err != nil {
			continue
		}
		// Expired entries never match.
		if now.After(until) {
			continue
		}
		// Optional rule scoping: a mismatched rule id disqualifies the entry.
		if entry.RuleID != nil && rule != nil && rule.ID > 0 && *entry.RuleID != rule.ID {
			continue
		}
		// Optional severity scoping: match against either the event's or the
		// rule's severity, case-insensitively.
		if len(entry.Severities) > 0 {
			match := false
			for _, s := range entry.Severities {
				if strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(event.Severity)) || strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(rule.Severity)) {
					match = true
					break
				}
			}
			if !match {
				continue
			}
		}
		return true
	}
	return false
}
// tryAcquireLeaderLock attempts to become the evaluation leader for this
// cycle. It returns (release, true) when evaluation may proceed; release is
// non-nil only when a Redis lock was actually taken and must be invoked at
// the end of the cycle. When locking is disabled or Redis is not configured,
// the evaluator runs unlocked.
func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, lock OpsDistributedLockSettings) (func(), bool) {
	if !lock.Enabled {
		return nil, true
	}
	if s.redisClient == nil {
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock")
		})
		return nil, true
	}
	key := strings.TrimSpace(lock.Key)
	if key == "" {
		key = opsAlertEvaluatorLeaderLockKey
	}
	ttl := time.Duration(lock.TTLSeconds) * time.Second
	if ttl <= 0 {
		ttl = opsAlertEvaluatorLeaderLockTTL
	}
	// SETNX with TTL: the stored value is our instanceID so release can
	// verify ownership (compare-and-delete via the Lua script).
	ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
	if err != nil {
		// Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky.
		// Single-node deployments can disable the distributed lock via runtime settings.
		// NOTE(review): this shares warnNoRedisOnce with the "not configured"
		// branch above, so only one of the two messages can ever be logged
		// per process lifetime.
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
		})
		return nil, false
	}
	if !ok {
		s.maybeLogSkip(key)
		return nil, false
	}
	return func() {
		// Release only if we still own the key: the TTL may have expired and
		// the lock been re-acquired by another instance.
		_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
	}, true
}
// maybeLogSkip logs the "lock held by another instance" message at most once
// per opsAlertEvaluatorSkipLogInterval to avoid log spam in multi-node
// deployments.
func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
	s.skipLogMu.Lock()
	defer s.skipLogMu.Unlock()
	now := time.Now()
	if last := s.skipLogAt; !last.IsZero() && now.Sub(last) < opsAlertEvaluatorSkipLogInterval {
		return
	}
	s.skipLogAt = now
	log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
}
// recordHeartbeatSuccess persists a successful evaluator pass (run time,
// duration, summary) to the ops job-heartbeat table. The write is
// best-effort telemetry, so its error is deliberately discarded.
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
	if s == nil || s.opsRepo == nil {
		return
	}
	summary := strings.TrimSpace(result)
	if summary == "" {
		summary = "ok"
	}
	summary = truncateString(summary, 2048)
	durMs := duration.Milliseconds()
	now := time.Now().UTC()
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAlertEvaluatorJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &now,
		LastDurationMs: &durMs,
		LastResult:     &summary,
	})
}
// recordHeartbeatError persists a failed evaluator pass with the truncated
// error text. Best-effort: the heartbeat write error itself is ignored.
func (s *OpsAlertEvaluatorService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
	if s == nil || s.opsRepo == nil || err == nil {
		return
	}
	errMsg := truncateString(err.Error(), 2048)
	durMs := duration.Milliseconds()
	now := time.Now().UTC()
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAlertEvaluatorJobName,
		LastRunAt:      &runAt,
		LastErrorAt:    &now,
		LastError:      &errMsg,
		LastDurationMs: &durMs,
	})
}
// htmlEscape sanitizes s for embedding in the HTML email body by replacing
// the five HTML-significant characters with entity references. The
// ampersand is replaced first so the entities introduced by the later
// replacements are not double-escaped, which yields the same output as a
// single-pass replacer.
func htmlEscape(s string) string {
	out := strings.ReplaceAll(s, "&", "&amp;")
	out = strings.ReplaceAll(out, "<", "&lt;")
	out = strings.ReplaceAll(out, ">", "&gt;")
	out = strings.ReplaceAll(out, `"`, "&quot;")
	out = strings.ReplaceAll(out, "'", "&#39;")
	return out
}
// slidingWindowLimiter caps the number of events (alert emails) allowed
// within a trailing time window.
type slidingWindowLimiter struct {
	mu     sync.Mutex
	limit  int           // max events per window; <= 0 disables limiting
	window time.Duration // trailing window length
	sent   []time.Time   // timestamps of recorded events still inside the window
}
// newSlidingWindowLimiter builds a limiter allowing at most limit events per
// sliding window. limit <= 0 means unlimited; a non-positive window defaults
// to one hour.
func newSlidingWindowLimiter(limit int, window time.Duration) *slidingWindowLimiter {
	if window <= 0 {
		window = time.Hour
	}
	l := &slidingWindowLimiter{limit: limit, window: window}
	l.sent = []time.Time{}
	return l
}
// SetLimit updates the per-window cap; it is re-applied on every evaluation
// cycle so configuration changes take effect without a restart. A value
// <= 0 disables limiting.
func (l *slidingWindowLimiter) SetLimit(limit int) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.limit = limit
}
// Allow reports whether another event may be recorded at time now, and
// records it when permitted. A non-positive limit disables limiting
// entirely (nothing is recorded in that case). Timestamps that have fallen
// out of the trailing window are pruned in place, reusing the backing array.
func (l *slidingWindowLimiter) Allow(now time.Time) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.limit <= 0 {
		return true
	}
	cutoff := now.Add(-l.window)
	n := 0
	for _, ts := range l.sent {
		if ts.After(cutoff) {
			l.sent[n] = ts
			n++
		}
	}
	l.sent = l.sent[:n]
	if len(l.sent) >= l.limit {
		return false
	}
	l.sent = append(l.sent, now)
	return true
}
// computeGroupAvailableRatio returns the percentage of available accounts in
// a group, i.e. (AvailableCount / TotalAccounts) * 100. A nil group or a
// group with no accounts yields 0.
func computeGroupAvailableRatio(group *GroupAvailability) float64 {
	if group == nil {
		return 0
	}
	total := group.TotalAccounts
	if total <= 0 {
		return 0
	}
	return (float64(group.AvailableCount) / float64(total)) * 100
}
// countAccountsByCondition returns how many non-nil accounts satisfy the
// predicate. A nil predicate or an empty/nil map counts zero.
func countAccountsByCondition(accounts map[int64]*AccountAvailability, condition func(*AccountAvailability) bool) int64 {
	if condition == nil {
		return 0
	}
	var total int64
	for _, acc := range accounts {
		if acc == nil {
			continue
		}
		if condition(acc) {
			total++
		}
	}
	return total
}
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// stubOpsRepo is a test double: it embeds OpsRepository (unimplemented
// methods panic if called) and overrides GetDashboardOverview with canned
// data or a canned error.
type stubOpsRepo struct {
	OpsRepository
	overview *OpsDashboardOverview // returned when err is nil; empty overview substituted if nil
	err      error                 // returned verbatim when set
}
// GetDashboardOverview returns the stubbed error or overview; when neither
// is configured it yields an empty overview so callers always receive a
// non-nil value.
func (s *stubOpsRepo) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
	switch {
	case s.err != nil:
		return nil, s.err
	case s.overview != nil:
		return s.overview, nil
	default:
		return &OpsDashboardOverview{}, nil
	}
}
func TestComputeGroupAvailableRatio(t *testing.T) {
t.Parallel()
t.Run("正常情况: 10个账号, 8个可用 = 80%", func(t *testing.T) {
t.Parallel()
got := computeGroupAvailableRatio(&GroupAvailability{
TotalAccounts: 10,
AvailableCount: 8,
})
require.InDelta(t, 80.0, got, 0.0001)
})
t.Run("边界情况: TotalAccounts = 0 应返回 0", func(t *testing.T) {
t.Parallel()
got := computeGroupAvailableRatio(&GroupAvailability{
TotalAccounts: 0,
AvailableCount: 8,
})
require.Equal(t, 0.0, got)
})
t.Run("边界情况: AvailableCount = 0 应返回 0%", func(t *testing.T) {
t.Parallel()
got := computeGroupAvailableRatio(&GroupAvailability{
TotalAccounts: 10,
AvailableCount: 0,
})
require.Equal(t, 0.0, got)
})
}
// TestCountAccountsByCondition exercises the two predicates used by the
// evaluator (rate-limited and errored-but-schedulable accounts) plus the
// empty-map boundary case.
func TestCountAccountsByCondition(t *testing.T) {
	t.Parallel()
	t.Run("测试限流账号统计: acc.IsRateLimited", func(t *testing.T) {
		t.Parallel()
		accounts := map[int64]*AccountAvailability{
			1: {IsRateLimited: true},
			2: {IsRateLimited: false},
			3: {IsRateLimited: true},
		}
		got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})
		require.Equal(t, int64(2), got)
	})
	t.Run("测试错误账号统计(排除临时不可调度): acc.HasError && acc.TempUnschedulableUntil == nil", func(t *testing.T) {
		t.Parallel()
		// Account 2 has an error but is temporarily unschedulable, so the
		// predicate must exclude it.
		until := time.Now().UTC().Add(5 * time.Minute)
		accounts := map[int64]*AccountAvailability{
			1: {HasError: true},
			2: {HasError: true, TempUnschedulableUntil: &until},
			3: {HasError: false},
		}
		got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
			return acc.HasError && acc.TempUnschedulableUntil == nil
		})
		require.Equal(t, int64(1), got)
	})
	t.Run("边界情况: 空 map 应返回 0", func(t *testing.T) {
		t.Parallel()
		got := countAccountsByCondition(map[int64]*AccountAvailability{}, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})
		require.Equal(t, int64(0), got)
	})
}
// TestComputeRuleMetricNewIndicators covers the availability-based metric
// types (group_available_accounts/ratio, account_rate_limited_count,
// account_error_count) against a fixed availability fixture, including the
// requirement that group-scoped metrics fail without a group_id.
func TestComputeRuleMetricNewIndicators(t *testing.T) {
	t.Parallel()
	groupID := int64(101)
	platform := "openai"
	// Fixture: 10 accounts, 8 available; 2 rate-limited; 2 with errors, one
	// of which is temporarily unschedulable (and thus excluded from
	// account_error_count).
	availability := &OpsAccountAvailability{
		Group: &GroupAvailability{
			GroupID:        groupID,
			TotalAccounts:  10,
			AvailableCount: 8,
		},
		Accounts: map[int64]*AccountAvailability{
			1: {IsRateLimited: true},
			2: {IsRateLimited: true},
			3: {HasError: true},
			4: {HasError: true, TempUnschedulableUntil: timePtr(time.Now().UTC().Add(2 * time.Minute))},
			5: {HasError: false, IsRateLimited: false},
		},
	}
	// Inject the fixture via the service's availability hook; the repo stub
	// only backs the overview-based metric path, unused here.
	opsService := &OpsService{
		getAccountAvailability: func(_ context.Context, _ string, _ *int64) (*OpsAccountAvailability, error) {
			return availability, nil
		},
	}
	svc := &OpsAlertEvaluatorService{
		opsService: opsService,
		opsRepo:    &stubOpsRepo{overview: &OpsDashboardOverview{}},
	}
	start := time.Now().UTC().Add(-5 * time.Minute)
	end := time.Now().UTC()
	ctx := context.Background()
	tests := []struct {
		name       string
		metricType string
		groupID    *int64
		wantValue  float64
		wantOK     bool
	}{
		{
			name:       "group_available_accounts",
			metricType: "group_available_accounts",
			groupID:    &groupID,
			wantValue:  8,
			wantOK:     true,
		},
		{
			name:       "group_available_ratio",
			metricType: "group_available_ratio",
			groupID:    &groupID,
			wantValue:  80.0,
			wantOK:     true,
		},
		{
			name:       "account_rate_limited_count",
			metricType: "account_rate_limited_count",
			groupID:    nil,
			wantValue:  2,
			wantOK:     true,
		},
		{
			name:       "account_error_count",
			metricType: "account_error_count",
			groupID:    nil,
			wantValue:  1,
			wantOK:     true,
		},
		{
			name:       "group_available_accounts without group_id returns false",
			metricType: "group_available_accounts",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
		{
			name:       "group_available_ratio without group_id returns false",
			metricType: "group_available_ratio",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			rule := &OpsAlertRule{
				MetricType: tt.metricType,
			}
			gotValue, gotOK := svc.computeRuleMetric(ctx, rule, nil, start, end, platform, tt.groupID)
			require.Equal(t, tt.wantOK, gotOK)
			if !tt.wantOK {
				return
			}
			require.InDelta(t, tt.wantValue, gotValue, 0.0001)
		})
	}
}
package service
import "time"
// Ops alert rule/event models.
//
// NOTE: These are admin-facing DTOs and intentionally keep JSON naming aligned
// with the existing ops dashboard frontend (backup style).
// Lifecycle states of an ops alert event.
const (
	// OpsAlertStatusFiring: breach detected and not yet recovered.
	OpsAlertStatusFiring = "firing"
	// OpsAlertStatusResolved: auto-resolved by the evaluator once the metric recovered.
	OpsAlertStatusResolved = "resolved"
	// OpsAlertStatusManualResolved: closed manually by an operator.
	OpsAlertStatusManualResolved = "manual_resolved"
)
// OpsAlertRule is the admin-facing definition of an alert condition:
// "metric <operator> threshold over window_minutes, sustained for
// sustained_minutes, with cooldown_minutes between firings".
type OpsAlertRule struct {
	ID               int64          `json:"id"`
	Name             string         `json:"name"`
	Description      string         `json:"description"`
	Enabled          bool           `json:"enabled"`
	Severity         string         `json:"severity"` // e.g. P0/P1; mapped to email severity elsewhere
	MetricType       string         `json:"metric_type"`
	Operator         string         `json:"operator"` // one of >, >=, <, <=, ==, !=
	Threshold        float64        `json:"threshold"`
	WindowMinutes    int            `json:"window_minutes"`    // metric aggregation window; <= 0 treated as 1
	SustainedMinutes int            `json:"sustained_minutes"` // breach must persist this long before firing
	CooldownMinutes  int            `json:"cooldown_minutes"`  // minimum gap between consecutive firings
	NotifyEmail      bool           `json:"notify_email"`
	Filters          map[string]any `json:"filters,omitempty"` // optional scope: platform / group_id / region
	LastTriggeredAt  *time.Time     `json:"last_triggered_at,omitempty"`
	CreatedAt        time.Time      `json:"created_at"`
	UpdatedAt        time.Time      `json:"updated_at"`
}
// OpsAlertEvent is a single firing (and eventual resolution) of an alert
// rule, including the metric/threshold snapshot taken at fire time.
type OpsAlertEvent struct {
	ID             int64          `json:"id"`
	RuleID         int64          `json:"rule_id"`
	Severity       string         `json:"severity"`
	Status         string         `json:"status"` // one of the OpsAlertStatus* constants
	Title          string         `json:"title"`
	Description    string         `json:"description"`
	MetricValue    *float64       `json:"metric_value,omitempty"`    // observed value at fire time
	ThresholdValue *float64       `json:"threshold_value,omitempty"` // rule threshold at fire time
	Dimensions     map[string]any `json:"dimensions,omitempty"`      // scope (platform/group_id) if any
	FiredAt        time.Time      `json:"fired_at"`
	ResolvedAt     *time.Time     `json:"resolved_at,omitempty"`
	EmailSent      bool           `json:"email_sent"`
	CreatedAt      time.Time      `json:"created_at"`
}
// OpsAlertSilence suppresses firing for a rule within a platform/group/region
// scope until the given time (used by the evaluator's scoped-silencing check).
type OpsAlertSilence struct {
	ID        int64     `json:"id"`
	RuleID    int64     `json:"rule_id"`
	Platform  string    `json:"platform"`
	GroupID   *int64    `json:"group_id,omitempty"`
	Region    *string   `json:"region,omitempty"`
	Until     time.Time `json:"until"` // silence expires at this instant
	Reason    string    `json:"reason"`
	CreatedBy *int64    `json:"created_by,omitempty"` // admin user id, when known
	CreatedAt time.Time `json:"created_at"`
}
// OpsAlertEventFilter selects alert events for listing, with cursor-based
// pagination and optional attribute/dimension filters.
type OpsAlertEventFilter struct {
	Limit int // max rows to return
	// Cursor pagination (descending by fired_at, then id).
	BeforeFiredAt *time.Time
	BeforeID      *int64
	// Optional filters.
	Status    string
	Severity  string
	EmailSent *bool
	StartTime *time.Time // inclusive lower bound on fired_at
	EndTime   *time.Time // upper bound on fired_at
	// Dimensions filters (best-effort).
	Platform string
	GroupID  *int64
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment