Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2fe8932c
"backend/internal/service/vscode:/vscode.git/clone" did not exist on "876e85e7adce3869e1c59f6a3ff7752ff92ac435"
Unverified
Commit
2fe8932c
authored
Feb 03, 2026
by
Call White
Committed by
GitHub
Feb 03, 2026
Browse files
Merge pull request #3 from cyhhao/main
merge to main
parents
2f2e76f9
adb77af1
Changes
267
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/http_upstream_port.go
View file @
2fe8932c
...
@@ -10,6 +10,7 @@ import "net/http"
...
@@ -10,6 +10,7 @@ import "net/http"
// - 支持可选代理配置
// - 支持可选代理配置
// - 支持账户级连接池隔离
// - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用
// - 实现类负责连接池管理和复用
// - 支持可选的 TLS 指纹伪装
type
HTTPUpstream
interface
{
type
HTTPUpstream
interface
{
// Do 执行 HTTP 请求
// Do 执行 HTTP 请求
//
//
...
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
...
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期
// - 响应体可能已被包装以跟踪请求生命周期
Do
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
http
.
Response
,
error
)
Do
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
http
.
Response
,
error
)
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
DoWithTLS
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
enableTLSFingerprint
bool
)
(
*
http
.
Response
,
error
)
}
}
backend/internal/service/identity_service.go
View file @
2fe8932c
...
@@ -8,9 +8,11 @@ import (
...
@@ -8,9 +8,11 @@ import (
"encoding/json"
"encoding/json"
"fmt"
"fmt"
"log"
"log"
"log/slog"
"net/http"
"net/http"
"regexp"
"regexp"
"strconv"
"strconv"
"strings"
"time"
"time"
)
)
...
@@ -24,13 +26,13 @@ var (
...
@@ -24,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用)
// 默认指纹值(当客户端未提供时使用)
var
defaultFingerprint
=
Fingerprint
{
var
defaultFingerprint
=
Fingerprint
{
UserAgent
:
"claude-cli/2.1.2 (external, cli)"
,
UserAgent
:
"claude-cli/2.1.2
2
(external, cli)"
,
StainlessLang
:
"js"
,
StainlessLang
:
"js"
,
StainlessPackageVersion
:
"0.70.0"
,
StainlessPackageVersion
:
"0.70.0"
,
StainlessOS
:
"Linux"
,
StainlessOS
:
"Linux"
,
StainlessArch
:
"
x
64"
,
StainlessArch
:
"
arm
64"
,
StainlessRuntime
:
"node"
,
StainlessRuntime
:
"node"
,
StainlessRuntimeVersion
:
"v24.3.0"
,
StainlessRuntimeVersion
:
"v24.
1
3.0"
,
}
}
// Fingerprint represents account fingerprint data
// Fingerprint represents account fingerprint data
...
@@ -49,6 +51,13 @@ type Fingerprint struct {
...
@@ -49,6 +51,13 @@ type Fingerprint struct {
type
IdentityCache
interface
{
type
IdentityCache
interface
{
GetFingerprint
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Fingerprint
,
error
)
GetFingerprint
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Fingerprint
,
error
)
SetFingerprint
(
ctx
context
.
Context
,
accountID
int64
,
fp
*
Fingerprint
)
error
SetFingerprint
(
ctx
context
.
Context
,
accountID
int64
,
fp
*
Fingerprint
)
error
// GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
// 返回的 sessionID 是一个 UUID 格式的字符串
// 如果不存在或已过期(15分钟无请求),返回空字符串
GetMaskedSessionID
(
ctx
context
.
Context
,
accountID
int64
)
(
string
,
error
)
// SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
// 每次调用都会刷新 TTL
SetMaskedSessionID
(
ctx
context
.
Context
,
accountID
int64
,
sessionID
string
)
error
}
}
// IdentityService 管理OAuth账号的请求身份指纹
// IdentityService 管理OAuth账号的请求身份指纹
...
@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
...
@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return
json
.
Marshal
(
reqMap
)
return
json
.
Marshal
(
reqMap
)
}
}
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
func
(
s
*
IdentityService
)
RewriteUserIDWithMasking
(
ctx
context
.
Context
,
body
[]
byte
,
account
*
Account
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
// 先执行常规的 RewriteUserID 逻辑
newBody
,
err
:=
s
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
cachedClientID
)
if
err
!=
nil
{
return
newBody
,
err
}
// 检查是否启用会话ID伪装
if
!
account
.
IsSessionIDMaskingEnabled
()
{
return
newBody
,
nil
}
// 解析重写后的 body,提取 user_id
var
reqMap
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
newBody
,
&
reqMap
);
err
!=
nil
{
return
newBody
,
nil
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
newBody
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
return
newBody
,
nil
}
// 查找 _session_ 的位置,替换其后的内容
const
sessionMarker
=
"_session_"
idx
:=
strings
.
LastIndex
(
userID
,
sessionMarker
)
if
idx
==
-
1
{
return
newBody
,
nil
}
// 获取或生成固定的伪装 session ID
maskedSessionID
,
err
:=
s
.
cache
.
GetMaskedSessionID
(
ctx
,
account
.
ID
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get masked session ID for account %d: %v"
,
account
.
ID
,
err
)
return
newBody
,
nil
}
if
maskedSessionID
==
""
{
// 首次或已过期,生成新的伪装 session ID
maskedSessionID
=
generateRandomUUID
()
log
.
Printf
(
"Generated new masked session ID for account %d: %s"
,
account
.
ID
,
maskedSessionID
)
}
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
if
err
:=
s
.
cache
.
SetMaskedSessionID
(
ctx
,
account
.
ID
,
maskedSessionID
);
err
!=
nil
{
log
.
Printf
(
"Warning: failed to set masked session ID for account %d: %v"
,
account
.
ID
,
err
)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
newUserID
:=
userID
[
:
idx
+
len
(
sessionMarker
)]
+
maskedSessionID
slog
.
Debug
(
"session_id_masking_applied"
,
"account_id"
,
account
.
ID
,
"before"
,
userID
,
"after"
,
newUserID
,
)
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
return
json
.
Marshal
(
reqMap
)
}
// generateRandomUUID 生成随机 UUID v4 格式字符串
func
generateRandomUUID
()
string
{
b
:=
make
([]
byte
,
16
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
// fallback: 使用时间戳生成
h
:=
sha256
.
Sum256
([]
byte
(
fmt
.
Sprintf
(
"%d"
,
time
.
Now
()
.
UnixNano
())))
b
=
h
[
:
16
]
}
// 设置 UUID v4 版本和变体位
b
[
6
]
=
(
b
[
6
]
&
0x0f
)
|
0x40
b
[
8
]
=
(
b
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
b
[
0
:
4
],
b
[
4
:
6
],
b
[
6
:
8
],
b
[
8
:
10
],
b
[
10
:
16
])
}
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
func
generateClientID
()
string
{
func
generateClientID
()
string
{
b
:=
make
([]
byte
,
32
)
b
:=
make
([]
byte
,
32
)
...
...
backend/internal/service/oauth_service.go
View file @
2fe8932c
...
@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
...
@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
// GenerateAuthURL generates an OAuth authorization URL with full scope
// GenerateAuthURL generates an OAuth authorization URL with full scope
func
(
s
*
OAuthService
)
GenerateAuthURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
*
GenerateAuthURLResult
,
error
)
{
func
(
s
*
OAuthService
)
GenerateAuthURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
*
GenerateAuthURLResult
,
error
)
{
scope
:=
fmt
.
Sprintf
(
"%s %s"
,
oauth
.
ScopeProfile
,
oauth
.
ScopeInference
)
return
s
.
generateAuthURLWithScope
(
ctx
,
oauth
.
ScopeOAuth
,
proxyID
)
return
s
.
generateAuthURLWithScope
(
ctx
,
scope
,
proxyID
)
}
}
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
...
@@ -123,6 +122,7 @@ type TokenInfo struct {
...
@@ -123,6 +122,7 @@ type TokenInfo struct {
Scope
string
`json:"scope,omitempty"`
Scope
string
`json:"scope,omitempty"`
OrgUUID
string
`json:"org_uuid,omitempty"`
OrgUUID
string
`json:"org_uuid,omitempty"`
AccountUUID
string
`json:"account_uuid,omitempty"`
AccountUUID
string
`json:"account_uuid,omitempty"`
EmailAddress
string
`json:"email_address,omitempty"`
}
}
// ExchangeCode exchanges authorization code for tokens
// ExchangeCode exchanges authorization code for tokens
...
@@ -176,7 +176,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
...
@@ -176,7 +176,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
}
}
// Determine scope and if this is a setup token
// Determine scope and if this is a setup token
scope
:=
fmt
.
Sprintf
(
"%s %s"
,
oauth
.
ScopeProfile
,
oauth
.
ScopeInference
)
// Internal API call uses ScopeAPI (org:create_api_key not supported)
scope
:=
oauth
.
ScopeAPI
isSetupToken
:=
false
isSetupToken
:=
false
if
input
.
Scope
==
"inference"
{
if
input
.
Scope
==
"inference"
{
scope
=
oauth
.
ScopeInference
scope
=
oauth
.
ScopeInference
...
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
...
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
tokenInfo
.
OrgUUID
=
tokenResp
.
Organization
.
UUID
tokenInfo
.
OrgUUID
=
tokenResp
.
Organization
.
UUID
log
.
Printf
(
"[OAuth] Got org_uuid: %s"
,
tokenInfo
.
OrgUUID
)
log
.
Printf
(
"[OAuth] Got org_uuid: %s"
,
tokenInfo
.
OrgUUID
)
}
}
if
tokenResp
.
Account
!=
nil
&&
tokenResp
.
Account
.
UUID
!=
""
{
if
tokenResp
.
Account
!=
nil
{
tokenInfo
.
AccountUUID
=
tokenResp
.
Account
.
UUID
if
tokenResp
.
Account
.
UUID
!=
""
{
log
.
Printf
(
"[OAuth] Got account_uuid: %s"
,
tokenInfo
.
AccountUUID
)
tokenInfo
.
AccountUUID
=
tokenResp
.
Account
.
UUID
log
.
Printf
(
"[OAuth] Got account_uuid: %s"
,
tokenInfo
.
AccountUUID
)
}
if
tokenResp
.
Account
.
EmailAddress
!=
""
{
tokenInfo
.
EmailAddress
=
tokenResp
.
Account
.
EmailAddress
log
.
Printf
(
"[OAuth] Got email_address: %s"
,
tokenInfo
.
EmailAddress
)
}
}
}
return
tokenInfo
,
nil
return
tokenInfo
,
nil
...
...
backend/internal/service/openai_codex_transform.go
View file @
2fe8932c
...
@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
...
@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
}
modified
:=
false
modified
:=
false
for
idx
,
tool
:=
range
tools
{
validTools
:=
make
([]
any
,
0
,
len
(
tools
))
for
_
,
tool
:=
range
tools
{
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
if
!
ok
{
if
!
ok
{
// Keep unknown structure as-is to avoid breaking upstream behavior.
validTools
=
append
(
validTools
,
tool
)
continue
continue
}
}
toolType
,
_
:=
toolMap
[
"type"
]
.
(
string
)
toolType
,
_
:=
toolMap
[
"type"
]
.
(
string
)
if
strings
.
TrimSpace
(
toolType
)
!=
"function"
{
toolType
=
strings
.
TrimSpace
(
toolType
)
if
toolType
!=
"function"
{
validTools
=
append
(
validTools
,
toolMap
)
continue
continue
}
}
function
,
ok
:=
toolMap
[
"function"
]
.
(
map
[
string
]
any
)
// OpenAI Responses-style tools use top-level name/parameters.
if
!
ok
{
if
name
,
ok
:=
toolMap
[
"name"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
name
)
!=
""
{
validTools
=
append
(
validTools
,
toolMap
)
continue
}
// ChatCompletions-style tools use {type:"function", function:{...}}.
functionValue
,
hasFunction
:=
toolMap
[
"function"
]
function
,
ok
:=
functionValue
.
(
map
[
string
]
any
)
if
!
hasFunction
||
functionValue
==
nil
||
!
ok
||
function
==
nil
{
// Drop invalid function tools.
modified
=
true
continue
continue
}
}
...
@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
...
@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
}
}
}
tools
[
idx
]
=
toolMap
validTools
=
append
(
validTools
,
toolMap
)
}
}
if
modified
{
if
modified
{
reqBody
[
"tools"
]
=
t
ools
reqBody
[
"tools"
]
=
validT
ools
}
}
return
modified
return
modified
...
...
backend/internal/service/openai_codex_transform_test.go
View file @
2fe8932c
...
@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
...
@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
require
.
False
(
t
,
hasID
)
require
.
False
(
t
,
hasID
)
}
}
func
TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools
(
t
*
testing
.
T
)
{
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
"tools"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"function"
,
"name"
:
"bash"
,
"description"
:
"desc"
,
"parameters"
:
map
[
string
]
any
{
"type"
:
"object"
},
},
map
[
string
]
any
{
"type"
:
"function"
,
"function"
:
nil
,
},
},
}
applyCodexOAuthTransform
(
reqBody
)
tools
,
ok
:=
reqBody
[
"tools"
]
.
([]
any
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
tools
,
1
)
first
,
ok
:=
tools
[
0
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"function"
,
first
[
"type"
])
require
.
Equal
(
t
,
"bash"
,
first
[
"name"
])
}
func
TestApplyCodexOAuthTransform_EmptyInput
(
t
*
testing
.
T
)
{
func
TestApplyCodexOAuthTransform_EmptyInput
(
t
*
testing
.
T
)
{
// 空 input 应保持为空且不触发异常。
// 空 input 应保持为空且不触发异常。
setupCodexCache
(
t
)
setupCodexCache
(
t
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
2fe8932c
...
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
...
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
UpdatedAt
string
`json:"updated_at,omitempty"`
UpdatedAt
string
`json:"updated_at,omitempty"`
}
}
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
type
NormalizedCodexLimits
struct
{
Used5hPercent
*
float64
Reset5hSeconds
*
int
Window5hMinutes
*
int
Used7dPercent
*
float64
Reset7dSeconds
*
int
Window7dMinutes
*
int
}
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
// Returns nil if snapshot is nil or has no useful data.
func
(
s
*
OpenAICodexUsageSnapshot
)
Normalize
()
*
NormalizedCodexLimits
{
if
s
==
nil
{
return
nil
}
result
:=
&
NormalizedCodexLimits
{}
primaryMins
:=
0
secondaryMins
:=
0
hasPrimaryWindow
:=
false
hasSecondaryWindow
:=
false
if
s
.
PrimaryWindowMinutes
!=
nil
{
primaryMins
=
*
s
.
PrimaryWindowMinutes
hasPrimaryWindow
=
true
}
if
s
.
SecondaryWindowMinutes
!=
nil
{
secondaryMins
=
*
s
.
SecondaryWindowMinutes
hasSecondaryWindow
=
true
}
// Determine mapping based on window_minutes
use5hFromPrimary
:=
false
use7dFromPrimary
:=
false
if
hasPrimaryWindow
&&
hasSecondaryWindow
{
// Both known: smaller window is 5h, larger is 7d
if
primaryMins
<
secondaryMins
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasPrimaryWindow
{
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
if
primaryMins
<=
360
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasSecondaryWindow
{
// Only secondary known: classify by threshold
if
secondaryMins
<=
360
{
// 5h from secondary, so primary (if any data) is 7d
use7dFromPrimary
=
true
}
else
{
// 7d from secondary, so primary (if any data) is 5h
use5hFromPrimary
=
true
}
}
else
{
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
use7dFromPrimary
=
true
}
// Assign values
if
use5hFromPrimary
{
result
.
Used5hPercent
=
s
.
PrimaryUsedPercent
result
.
Reset5hSeconds
=
s
.
PrimaryResetAfterSeconds
result
.
Window5hMinutes
=
s
.
PrimaryWindowMinutes
result
.
Used7dPercent
=
s
.
SecondaryUsedPercent
result
.
Reset7dSeconds
=
s
.
SecondaryResetAfterSeconds
result
.
Window7dMinutes
=
s
.
SecondaryWindowMinutes
}
else
if
use7dFromPrimary
{
result
.
Used7dPercent
=
s
.
PrimaryUsedPercent
result
.
Reset7dSeconds
=
s
.
PrimaryResetAfterSeconds
result
.
Window7dMinutes
=
s
.
PrimaryWindowMinutes
result
.
Used5hPercent
=
s
.
SecondaryUsedPercent
result
.
Reset5hSeconds
=
s
.
SecondaryResetAfterSeconds
result
.
Window5hMinutes
=
s
.
SecondaryWindowMinutes
}
return
result
}
// OpenAIUsage represents OpenAI API response usage
// OpenAIUsage represents OpenAI API response usage
type
OpenAIUsage
struct
{
type
OpenAIUsage
struct
{
InputTokens
int
`json:"input_tokens"`
InputTokens
int
`json:"input_tokens"`
...
@@ -133,12 +219,30 @@ func NewOpenAIGatewayService(
...
@@ -133,12 +219,30 @@ func NewOpenAIGatewayService(
}
}
}
}
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
func
(
s
*
OpenAIGatewayService
)
GenerateSessionHash
(
c
*
gin
.
Context
)
string
{
//
sessionID
:=
c
.
GetHeader
(
"session_id"
)
// Priority:
// 1. Header: session_id
// 2. Header: conversation_id
// 3. Body: prompt_cache_key (opencode)
func
(
s
*
OpenAIGatewayService
)
GenerateSessionHash
(
c
*
gin
.
Context
,
reqBody
map
[
string
]
any
)
string
{
if
c
==
nil
{
return
""
}
sessionID
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"session_id"
))
if
sessionID
==
""
{
sessionID
=
strings
.
TrimSpace
(
c
.
GetHeader
(
"conversation_id"
))
}
if
sessionID
==
""
&&
reqBody
!=
nil
{
if
v
,
ok
:=
reqBody
[
"prompt_cache_key"
]
.
(
string
);
ok
{
sessionID
=
strings
.
TrimSpace
(
v
)
}
}
if
sessionID
==
""
{
if
sessionID
==
""
{
return
""
return
""
}
}
hash
:=
sha256
.
Sum256
([]
byte
(
sessionID
))
hash
:=
sha256
.
Sum256
([]
byte
(
sessionID
))
return
hex
.
EncodeToString
(
hash
[
:
])
return
hex
.
EncodeToString
(
hash
[
:
])
}
}
...
@@ -162,81 +266,164 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
...
@@ -162,81 +266,164 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func
(
s
*
OpenAIGatewayService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
func
(
s
*
OpenAIGatewayService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
// 1. Check sticky session
cacheKey
:=
"openai:"
+
sessionHash
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
// 1. 尝试粘性会话命中
if
err
==
nil
&&
accountID
>
0
{
// Try sticky session hit
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
if
account
:=
s
.
tryStickySessionHit
(
ctx
,
groupID
,
sessionHash
,
cacheKey
,
requestedModel
,
excludedIDs
);
account
!=
nil
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
return
account
,
nil
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
// Refresh sticky session TTL
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
,
openaiStickySessionTTL
)
return
account
,
nil
}
}
}
}
}
// 2. Get schedulable OpenAI accounts
// 2. 获取可调度的 OpenAI 账号
// Get schedulable OpenAI accounts
accounts
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
)
accounts
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
}
// 3. Select by priority + LRU
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected
:=
s
.
selectBestAccount
(
accounts
,
requestedModel
,
excludedIDs
)
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available OpenAI accounts supporting model: %s"
,
requestedModel
)
}
return
nil
,
errors
.
New
(
"no available OpenAI accounts"
)
}
// 4. 设置粘性会话绑定
// Set sticky session binding
if
sessionHash
!=
""
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
openaiStickySessionTTL
)
}
return
selected
,
nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
func
(
s
*
OpenAIGatewayService
)
tryStickySessionHit
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
,
cacheKey
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
*
Account
{
if
sessionHash
==
""
{
return
nil
}
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
if
err
!=
nil
||
accountID
<=
0
{
return
nil
}
if
_
,
excluded
:=
excludedIDs
[
accountID
];
excluded
{
return
nil
}
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if
shouldClearStickySession
(
account
)
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
return
nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if
!
account
.
IsSchedulable
()
||
!
account
.
IsOpenAI
()
{
return
nil
}
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
openaiStickySessionTTL
)
return
account
}
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
// 返回 nil 表示无可用账号。
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func
(
s
*
OpenAIGatewayService
)
selectBestAccount
(
accounts
[]
Account
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
*
Account
{
var
selected
*
Account
var
selected
*
Account
for
i
:=
range
accounts
{
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
acc
:=
&
accounts
[
i
]
// 跳过被排除的账号
// Skip excluded accounts
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
continue
}
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
if
!
acc
.
IsSchedulable
()
{
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
if
!
acc
.
IsSchedulable
()
||
!
acc
.
IsOpenAI
()
{
continue
continue
}
}
// 检查模型支持
// Check model support
// Check model support
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
continue
continue
}
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
if
selected
==
nil
{
if
selected
==
nil
{
selected
=
acc
selected
=
acc
continue
continue
}
}
// Lower priority value means higher priority
if
acc
.
Priority
<
selected
.
Priority
{
if
s
.
isBetterAccount
(
acc
,
selected
)
{
selected
=
acc
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (both never used)
default
:
// Same priority, select least recently used
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
selected
=
acc
}
}
}
}
}
}
if
selected
==
nil
{
return
selected
if
requestedModel
!=
""
{
}
return
nil
,
fmt
.
Errorf
(
"no available OpenAI accounts supporting model: %s"
,
requestedModel
)
}
return
nil
,
errors
.
New
(
"no available OpenAI accounts"
)
}
// 4. Set sticky session
// isBetterAccount 判断 candidate 是否比 current 更优。
if
sessionHash
!=
""
{
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
,
selected
.
ID
,
openaiStickySessionTTL
)
//
// isBetterAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
func
(
s
*
OpenAIGatewayService
)
isBetterAccount
(
candidate
,
current
*
Account
)
bool
{
// 优先级更高(数值更小)
// Higher priority (lower value)
if
candidate
.
Priority
<
current
.
Priority
{
return
true
}
if
candidate
.
Priority
>
current
.
Priority
{
return
false
}
}
return
selected
,
nil
// 同优先级,比较最后使用时间
// Same priority, compare last used time
switch
{
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
!=
nil
:
// candidate 从未使用,优先
return
true
case
candidate
.
LastUsedAt
!=
nil
&&
current
.
LastUsedAt
==
nil
:
// current 从未使用,保持
return
false
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
==
nil
:
// 都未使用,保持
return
false
default
:
// 都使用过,选择最久未使用的
return
candidate
.
LastUsedAt
.
Before
(
*
current
.
LastUsedAt
)
}
}
}
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
...
@@ -307,29 +494,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -307,29 +494,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
if
err
==
nil
{
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
clearSticky
:=
shouldClearStickySession
(
account
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
clearSticky
{
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
,
openaiStickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
if
!
clearSticky
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
,
openaiStickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
return
&
AccountSelectionResult
{
Account
:
account
,
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
accountID
,
AccountID
:
accountID
,
MaxConcurrency
:
account
.
Concurrency
,
MaxConcurrency
:
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
},
nil
},
nil
}
}
}
}
}
}
}
...
@@ -760,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
...
@@ -760,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if
account
.
Type
==
AccountTypeOAuth
{
if
account
.
Type
==
AccountTypeOAuth
{
if
snapshot
:=
extractCodexUsage
Headers
(
resp
.
Header
);
snapshot
!=
nil
{
if
snapshot
:=
ParseCodexRateLimit
Headers
(
resp
.
Header
);
snapshot
!=
nil
{
s
.
updateCodexUsageSnapshot
(
ctx
,
account
.
ID
,
snapshot
)
s
.
updateCodexUsageSnapshot
(
ctx
,
account
.
ID
,
snapshot
)
}
}
}
}
...
@@ -1599,8 +1792,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -1599,8 +1792,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return
nil
return
nil
}
}
// extractCodexUsageHeaders extracts Codex usage limits from response headers
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
func
extractCodexUsageHeaders
(
headers
http
.
Header
)
*
OpenAICodexUsageSnapshot
{
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func
ParseCodexRateLimitHeaders
(
headers
http
.
Header
)
*
OpenAICodexUsageSnapshot
{
snapshot
:=
&
OpenAICodexUsageSnapshot
{}
snapshot
:=
&
OpenAICodexUsageSnapshot
{}
hasData
:=
false
hasData
:=
false
...
@@ -1674,6 +1868,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
...
@@ -1674,6 +1868,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
// Convert snapshot to map for merging into Extra
// Convert snapshot to map for merging into Extra
updates
:=
make
(
map
[
string
]
any
)
updates
:=
make
(
map
[
string
]
any
)
// Save raw primary/secondary fields for debugging/tracing
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_primary_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
updates
[
"codex_primary_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
}
...
@@ -1697,109 +1893,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
...
@@ -1697,109 +1893,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
}
}
updates
[
"codex_usage_updated_at"
]
=
snapshot
.
UpdatedAt
updates
[
"codex_usage_updated_at"
]
=
snapshot
.
UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes
// Normalize to canonical 5h/7d fields
// This fixes the issue where OpenAI's primary/secondary naming is reversed
if
normalized
:=
snapshot
.
Normalize
();
normalized
!=
nil
{
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
if
normalized
.
Used5hPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
normalized
.
Used5hPercent
// IMPORTANT: We can only reliably determine window type from window_minutes field
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
var
primaryWindowMins
,
secondaryWindowMins
int
var
hasPrimaryWindow
,
hasSecondaryWindow
bool
// Only use window_minutes for reliable window size comparison
if
snapshot
.
PrimaryWindowMinutes
!=
nil
{
primaryWindowMins
=
*
snapshot
.
PrimaryWindowMinutes
hasPrimaryWindow
=
true
}
if
snapshot
.
SecondaryWindowMinutes
!=
nil
{
secondaryWindowMins
=
*
snapshot
.
SecondaryWindowMinutes
hasSecondaryWindow
=
true
}
// Determine which is 5h and which is 7d
var
use5hFromPrimary
,
use7dFromPrimary
bool
var
use5hFromSecondary
,
use7dFromSecondary
bool
if
hasPrimaryWindow
&&
hasSecondaryWindow
{
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
if
primaryWindowMins
<
secondaryWindowMins
{
use5hFromPrimary
=
true
use7dFromSecondary
=
true
}
else
{
use5hFromSecondary
=
true
use7dFromPrimary
=
true
}
}
else
if
hasPrimaryWindow
{
// Only primary window size known: classify by absolute threshold
if
primaryWindowMins
<=
360
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasSecondaryWindow
{
// Only secondary window size known: classify by absolute threshold
if
secondaryWindowMins
<=
360
{
use5hFromSecondary
=
true
}
else
{
use7dFromSecondary
=
true
}
}
else
{
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if
snapshot
.
SecondaryUsedPercent
!=
nil
||
snapshot
.
SecondaryResetAfterSeconds
!=
nil
||
snapshot
.
SecondaryWindowMinutes
!=
nil
{
use5hFromSecondary
=
true
}
if
snapshot
.
PrimaryUsedPercent
!=
nil
||
snapshot
.
PrimaryResetAfterSeconds
!=
nil
||
snapshot
.
PrimaryWindowMinutes
!=
nil
{
use7dFromPrimary
=
true
}
}
// Write canonical 5h fields
if
use5hFromPrimary
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
if
snapshot
.
PrimaryResetAfterSeconds
!=
nil
{
updates
[
"codex_5h_reset_after_seconds"
]
=
*
snapshot
.
PrimaryResetAfterSeconds
}
if
snapshot
.
PrimaryWindowMinutes
!=
nil
{
updates
[
"codex_5h_window_minutes"
]
=
*
snapshot
.
PrimaryWindowMinutes
}
}
else
if
use5hFromSecondary
{
if
snapshot
.
SecondaryUsedPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
snapshot
.
SecondaryUsedPercent
}
if
snapshot
.
SecondaryResetAfterSeconds
!=
nil
{
updates
[
"codex_5h_reset_after_seconds"
]
=
*
snapshot
.
SecondaryResetAfterSeconds
}
if
snapshot
.
SecondaryWindowMinutes
!=
nil
{
updates
[
"codex_5h_window_minutes"
]
=
*
snapshot
.
SecondaryWindowMinutes
}
}
// Write canonical 7d fields
if
use7dFromPrimary
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_7d_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
}
if
snapshot
.
PrimaryResetAfter
Seconds
!=
nil
{
if
normalized
.
Reset5h
Seconds
!=
nil
{
updates
[
"codex_
7d
_reset_after_seconds"
]
=
*
snapshot
.
PrimaryResetAfter
Seconds
updates
[
"codex_
5h
_reset_after_seconds"
]
=
*
normalized
.
Reset5h
Seconds
}
}
if
snapshot
.
Primary
WindowMinutes
!=
nil
{
if
normalized
.
Window
5h
Minutes
!=
nil
{
updates
[
"codex_
7d
_window_minutes"
]
=
*
snapshot
.
Primary
WindowMinutes
updates
[
"codex_
5h
_window_minutes"
]
=
*
normalized
.
Window
5h
Minutes
}
}
}
else
if
use7dFromSecondary
{
if
normalized
.
Used7dPercent
!=
nil
{
if
snapshot
.
SecondaryUsedPercent
!=
nil
{
updates
[
"codex_7d_used_percent"
]
=
*
normalized
.
Used7dPercent
updates
[
"codex_7d_used_percent"
]
=
*
snapshot
.
SecondaryUsedPercent
}
}
if
snapshot
.
SecondaryResetAfter
Seconds
!=
nil
{
if
normalized
.
Reset7d
Seconds
!=
nil
{
updates
[
"codex_7d_reset_after_seconds"
]
=
*
snapshot
.
SecondaryResetAfter
Seconds
updates
[
"codex_7d_reset_after_seconds"
]
=
*
normalized
.
Reset7d
Seconds
}
}
if
snapshot
.
Secondary
WindowMinutes
!=
nil
{
if
normalized
.
Window
7d
Minutes
!=
nil
{
updates
[
"codex_7d_window_minutes"
]
=
*
snapshot
.
Secondary
WindowMinutes
updates
[
"codex_7d_window_minutes"
]
=
*
normalized
.
Window
7d
Minutes
}
}
}
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
2fe8932c
...
@@ -21,16 +21,42 @@ type stubOpenAIAccountRepo struct {
...
@@ -21,16 +21,42 @@ type stubOpenAIAccountRepo struct {
accounts
[]
Account
accounts
[]
Account
}
}
func
(
r
stubOpenAIAccountRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
for
i
:=
range
r
.
accounts
{
if
r
.
accounts
[
i
]
.
ID
==
id
{
return
&
r
.
accounts
[
i
],
nil
}
}
return
nil
,
errors
.
New
(
"account not found"
)
}
func
(
r
stubOpenAIAccountRepo
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
{
func
(
r
stubOpenAIAccountRepo
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
{
return
append
([]
Account
(
nil
),
r
.
accounts
...
),
nil
var
result
[]
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
Platform
==
platform
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
}
func
(
r
stubOpenAIAccountRepo
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
func
(
r
stubOpenAIAccountRepo
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
append
([]
Account
(
nil
),
r
.
accounts
...
),
nil
var
result
[]
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
Platform
==
platform
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
}
type
stubConcurrencyCache
struct
{
type
stubConcurrencyCache
struct
{
ConcurrencyCache
ConcurrencyCache
loadBatchErr
error
loadMap
map
[
int64
]
*
AccountLoadInfo
acquireResults
map
[
int64
]
bool
waitCounts
map
[
int64
]
int
skipDefaultLoad
bool
}
}
type
cancelReadCloser
struct
{}
type
cancelReadCloser
struct
{}
...
@@ -53,6 +79,11 @@ func (w *failingGinWriter) Write(p []byte) (int, error) {
...
@@ -53,6 +79,11 @@ func (w *failingGinWriter) Write(p []byte) (int, error) {
}
}
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
if
c
.
acquireResults
!=
nil
{
if
result
,
ok
:=
c
.
acquireResults
[
accountID
];
ok
{
return
result
,
nil
}
}
return
true
,
nil
return
true
,
nil
}
}
...
@@ -61,13 +92,118 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
...
@@ -61,13 +92,118 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
}
}
func
(
c
stubConcurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
func
(
c
stubConcurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
c
.
loadBatchErr
!=
nil
{
return
nil
,
c
.
loadBatchErr
}
out
:=
make
(
map
[
int64
]
*
AccountLoadInfo
,
len
(
accounts
))
out
:=
make
(
map
[
int64
]
*
AccountLoadInfo
,
len
(
accounts
))
if
c
.
skipDefaultLoad
&&
c
.
loadMap
!=
nil
{
for
_
,
acc
:=
range
accounts
{
if
load
,
ok
:=
c
.
loadMap
[
acc
.
ID
];
ok
{
out
[
acc
.
ID
]
=
load
}
}
return
out
,
nil
}
for
_
,
acc
:=
range
accounts
{
for
_
,
acc
:=
range
accounts
{
if
c
.
loadMap
!=
nil
{
if
load
,
ok
:=
c
.
loadMap
[
acc
.
ID
];
ok
{
out
[
acc
.
ID
]
=
load
continue
}
}
out
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
LoadRate
:
0
}
out
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
LoadRate
:
0
}
}
}
return
out
,
nil
return
out
,
nil
}
}
func
TestOpenAIGatewayService_GenerateSessionHash_Priority
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
svc
:=
&
OpenAIGatewayService
{}
// 1) session_id header wins
c
.
Request
.
Header
.
Set
(
"session_id"
,
"sess-123"
)
c
.
Request
.
Header
.
Set
(
"conversation_id"
,
"conv-456"
)
h1
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{
"prompt_cache_key"
:
"ses_aaa"
})
if
h1
==
""
{
t
.
Fatalf
(
"expected non-empty hash"
)
}
// 2) conversation_id used when session_id absent
c
.
Request
.
Header
.
Del
(
"session_id"
)
h2
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{
"prompt_cache_key"
:
"ses_aaa"
})
if
h2
==
""
{
t
.
Fatalf
(
"expected non-empty hash"
)
}
if
h1
==
h2
{
t
.
Fatalf
(
"expected different hashes for different keys"
)
}
// 3) prompt_cache_key used when both headers absent
c
.
Request
.
Header
.
Del
(
"conversation_id"
)
h3
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{
"prompt_cache_key"
:
"ses_aaa"
})
if
h3
==
""
{
t
.
Fatalf
(
"expected non-empty hash"
)
}
if
h2
==
h3
{
t
.
Fatalf
(
"expected different hashes for different keys"
)
}
// 4) empty when no signals
h4
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{})
if
h4
!=
""
{
t
.
Fatalf
(
"expected empty hash when no signals"
)
}
}
func
(
c
stubConcurrencyCache
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
c
.
waitCounts
!=
nil
{
if
count
,
ok
:=
c
.
waitCounts
[
accountID
];
ok
{
return
count
,
nil
}
}
return
0
,
nil
}
type
stubGatewayCache
struct
{
sessionBindings
map
[
string
]
int64
deletedSessions
map
[
string
]
int
}
func
(
c
*
stubGatewayCache
)
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
{
if
id
,
ok
:=
c
.
sessionBindings
[
sessionHash
];
ok
{
return
id
,
nil
}
return
0
,
errors
.
New
(
"not found"
)
}
func
(
c
*
stubGatewayCache
)
SetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
if
c
.
sessionBindings
==
nil
{
c
.
sessionBindings
=
make
(
map
[
string
]
int64
)
}
c
.
sessionBindings
[
sessionHash
]
=
accountID
return
nil
}
func
(
c
*
stubGatewayCache
)
RefreshSessionTTL
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
c
*
stubGatewayCache
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
if
c
.
sessionBindings
==
nil
{
return
nil
}
if
c
.
deletedSessions
==
nil
{
c
.
deletedSessions
=
make
(
map
[
string
]
int
)
}
c
.
deletedSessions
[
sessionHash
]
++
delete
(
c
.
sessionBindings
,
sessionHash
)
return
nil
}
func
TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable
(
t
*
testing
.
T
)
{
func
TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
now
:=
time
.
Now
()
resetAt
:=
now
.
Add
(
10
*
time
.
Minute
)
resetAt
:=
now
.
Add
(
10
*
time
.
Minute
)
...
@@ -158,6 +294,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
...
@@ -158,6 +294,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
}
}
}
}
func
TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession
(
t
*
testing
.
T
)
{
sessionHash
:=
"session-1"
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusDisabled
,
Schedulable
:
true
,
Concurrency
:
1
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
},
},
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:"
+
sessionHash
:
1
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
context
.
Background
(),
nil
,
sessionHash
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountForModelWithExclusions error: %v"
,
err
)
}
if
acc
==
nil
||
acc
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2, got %+v"
,
acc
)
}
if
cache
.
deletedSessions
[
"openai:"
+
sessionHash
]
!=
1
{
t
.
Fatalf
(
"expected sticky session to be deleted"
)
}
if
cache
.
sessionBindings
[
"openai:"
+
sessionHash
]
!=
2
{
t
.
Fatalf
(
"expected sticky session to bind to account 2"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession
(
t
*
testing
.
T
)
{
sessionHash
:=
"session-2"
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusDisabled
,
Schedulable
:
true
,
Concurrency
:
1
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
},
},
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:"
+
sessionHash
:
1
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
sessionHash
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
Account
==
nil
||
selection
.
Account
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2, got %+v"
,
selection
)
}
if
cache
.
deletedSessions
[
"openai:"
+
sessionHash
]
!=
1
{
t
.
Fatalf
(
"expected sticky session to be deleted"
)
}
if
cache
.
sessionBindings
[
"openai:"
+
sessionHash
]
!=
2
{
t
.
Fatalf
(
"expected sticky session to bind to account 2"
)
}
if
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
}
func
TestOpenAISelectAccountForModelWithExclusions_NoModelSupport
(
t
*
testing
.
T
)
{
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-3.5-turbo"
:
"gpt-3.5-turbo"
}},
},
},
}
cache
:=
&
stubGatewayCache
{}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
context
.
Background
(),
nil
,
""
,
"gpt-4"
,
nil
)
if
err
==
nil
{
t
.
Fatalf
(
"expected error for unsupported model"
)
}
if
acc
!=
nil
{
t
.
Fatalf
(
"expected nil account for unsupported model"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"supporting model"
)
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
2
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{
loadBatchErr
:
errors
.
New
(
"load batch failed"
),
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
"fallback"
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
Account
==
nil
{
t
.
Fatalf
(
"expected selection"
)
}
if
selection
.
Account
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2, got %d"
,
selection
.
Account
.
ID
)
}
if
cache
.
sessionBindings
[
"openai:fallback"
]
!=
2
{
t
.
Fatalf
(
"expected sticky session updated"
)
}
if
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
}
func
TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{
acquireResults
:
map
[
int64
]
bool
{
1
:
false
},
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
10
},
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
""
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
WaitPlan
==
nil
{
t
.
Fatalf
(
"expected wait plan fallback"
)
}
if
selection
.
Account
==
nil
||
selection
.
Account
.
ID
!=
1
{
t
.
Fatalf
(
"expected account 1"
)
}
}
func
TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding
(
t
*
testing
.
T
)
{
sessionHash
:=
"bind"
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
context
.
Background
(),
nil
,
sessionHash
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountForModelWithExclusions error: %v"
,
err
)
}
if
acc
==
nil
||
acc
.
ID
!=
1
{
t
.
Fatalf
(
"expected account 1"
)
}
if
cache
.
sessionBindings
[
"openai:"
+
sessionHash
]
!=
1
{
t
.
Fatalf
(
"expected sticky session binding"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan
(
t
*
testing
.
T
)
{
sessionHash
:=
"sticky-wait"
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:"
+
sessionHash
:
1
},
}
concurrencyCache
:=
stubConcurrencyCache
{
acquireResults
:
map
[
int64
]
bool
{
1
:
false
},
waitCounts
:
map
[
int64
]
int
{
1
:
0
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
sessionHash
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
WaitPlan
==
nil
{
t
.
Fatalf
(
"expected sticky wait plan"
)
}
if
selection
.
Account
==
nil
||
selection
.
Account
.
ID
!=
1
{
t
.
Fatalf
(
"expected account 1"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
80
},
2
:
{
AccountID
:
2
,
LoadRate
:
10
},
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
"load"
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
Account
==
nil
||
selection
.
Account
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2"
)
}
if
cache
.
sessionBindings
[
"openai:load"
]
!=
2
{
t
.
Fatalf
(
"expected sticky session updated"
)
}
}
func
TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback
(
t
*
testing
.
T
)
{
sessionHash
:=
"excluded"
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
2
},
},
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:"
+
sessionHash
:
1
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
excluded
:=
map
[
int64
]
struct
{}{
1
:
{}}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
context
.
Background
(),
nil
,
sessionHash
,
"gpt-4"
,
excluded
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountForModelWithExclusions error: %v"
,
err
)
}
if
acc
==
nil
||
acc
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2"
)
}
}
func
TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI
(
t
*
testing
.
T
)
{
sessionHash
:=
"non-openai"
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
2
},
},
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:"
+
sessionHash
:
1
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
context
.
Background
(),
nil
,
sessionHash
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountForModelWithExclusions error: %v"
,
err
)
}
if
acc
==
nil
||
acc
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2"
)
}
}
func
TestOpenAISelectAccountForModelWithExclusions_NoAccounts
(
t
*
testing
.
T
)
{
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{}}
cache
:=
&
stubGatewayCache
{}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
context
.
Background
(),
nil
,
""
,
""
,
nil
)
if
err
==
nil
{
t
.
Fatalf
(
"expected error for no accounts"
)
}
if
acc
!=
nil
{
t
.
Fatalf
(
"expected nil account"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"no available OpenAI accounts"
)
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_NoCandidates
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
resetAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
,
RateLimitResetAt
:
&
resetAt
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
""
,
"gpt-4"
,
nil
)
if
err
==
nil
{
t
.
Fatalf
(
"expected error for no candidates"
)
}
if
selection
!=
nil
{
t
.
Fatalf
(
"expected nil selection"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
100
},
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
""
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
WaitPlan
==
nil
{
t
.
Fatalf
(
"expected wait plan"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{
loadBatchErr
:
errors
.
New
(
"load batch failed"
),
acquireResults
:
map
[
int64
]
bool
{
1
:
false
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
""
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
WaitPlan
==
nil
{
t
.
Fatalf
(
"expected wait plan"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
50
},
},
skipDefaultLoad
:
true
,
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
""
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
Account
==
nil
||
selection
.
Account
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2"
)
}
}
func
TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed
(
t
*
testing
.
T
)
{
oldTime
:=
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
newTime
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Priority
:
1
,
LastUsedAt
:
&
newTime
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Priority
:
1
,
LastUsedAt
:
&
oldTime
},
},
}
cache
:=
&
stubGatewayCache
{}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
context
.
Background
(),
nil
,
""
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountForModelWithExclusions error: %v"
,
err
)
}
if
acc
==
nil
||
acc
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
lastUsed
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
,
LastUsedAt
:
&
lastUsed
},
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
},
},
}
cache
:=
&
stubGatewayCache
{}
concurrencyCache
:=
stubConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
10
},
2
:
{
AccountID
:
2
,
LoadRate
:
10
},
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
selection
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
context
.
Background
(),
&
groupID
,
""
,
"gpt-4"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectAccountWithLoadAwareness error: %v"
,
err
)
}
if
selection
==
nil
||
selection
.
Account
==
nil
||
selection
.
Account
.
ID
!=
2
{
t
.
Fatalf
(
"expected account 2"
)
}
}
func
TestOpenAIStreamingTimeout
(
t
*
testing
.
T
)
{
func
TestOpenAIStreamingTimeout
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
cfg
:=
&
config
.
Config
{
...
...
backend/internal/service/openai_oauth_service.go
View file @
2fe8932c
...
@@ -2,9 +2,10 @@ package service
...
@@ -2,9 +2,10 @@ package service
import
(
import
(
"context"
"context"
"
fmt
"
"
net/http
"
"time"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
)
...
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
...
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate PKCE values
// Generate PKCE values
state
,
err
:=
openai
.
GenerateState
()
state
,
err
:=
openai
.
GenerateState
()
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate state: %
w
"
,
err
)
return
nil
,
infraerrors
.
Newf
(
http
.
StatusInternalServerError
,
"OPENAI_OAUTH_STATE_FAILED"
,
"failed to generate state: %
v
"
,
err
)
}
}
codeVerifier
,
err
:=
openai
.
GenerateCodeVerifier
()
codeVerifier
,
err
:=
openai
.
GenerateCodeVerifier
()
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate code verifier: %
w
"
,
err
)
return
nil
,
infraerrors
.
Newf
(
http
.
StatusInternalServerError
,
"OPENAI_OAUTH_VERIFIER_FAILED"
,
"failed to generate code verifier: %
v
"
,
err
)
}
}
codeChallenge
:=
openai
.
GenerateCodeChallenge
(
codeVerifier
)
codeChallenge
:=
openai
.
GenerateCodeChallenge
(
codeVerifier
)
...
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
...
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate session ID
// Generate session ID
sessionID
,
err
:=
openai
.
GenerateSessionID
()
sessionID
,
err
:=
openai
.
GenerateSessionID
()
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate session ID: %
w
"
,
err
)
return
nil
,
infraerrors
.
Newf
(
http
.
StatusInternalServerError
,
"OPENAI_OAUTH_SESSION_FAILED"
,
"failed to generate session ID: %
v
"
,
err
)
}
}
// Get proxy URL if specified
// Get proxy URL if specified
var
proxyURL
string
var
proxyURL
string
if
proxyID
!=
nil
{
if
proxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
proxyID
)
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
proxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_PROXY_NOT_FOUND"
,
"proxy not found: %v"
,
err
)
}
if
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
proxyURL
=
proxy
.
URL
()
}
}
}
}
...
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
...
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Get session
// Get session
session
,
ok
:=
s
.
sessionStore
.
Get
(
input
.
SessionID
)
session
,
ok
:=
s
.
sessionStore
.
Get
(
input
.
SessionID
)
if
!
ok
{
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"session not found or expired"
)
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_SESSION_NOT_FOUND"
,
"session not found or expired"
)
}
}
// Get proxy URL
// Get proxy URL
: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL
:=
session
.
ProxyURL
proxyURL
:=
session
.
ProxyURL
if
input
.
ProxyID
!=
nil
{
if
input
.
ProxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
input
.
ProxyID
)
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
input
.
ProxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_PROXY_NOT_FOUND"
,
"proxy not found: %v"
,
err
)
}
if
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
proxyURL
=
proxy
.
URL
()
}
}
}
}
...
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
...
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Exchange code for token
// Exchange code for token
tokenResp
,
err
:=
s
.
oauthClient
.
ExchangeCode
(
ctx
,
input
.
Code
,
session
.
CodeVerifier
,
redirectURI
,
proxyURL
)
tokenResp
,
err
:=
s
.
oauthClient
.
ExchangeCode
(
ctx
,
input
.
Code
,
session
.
CodeVerifier
,
redirectURI
,
proxyURL
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to exchange code: %w"
,
err
)
return
nil
,
err
}
}
// Parse ID token to get user info
// Parse ID token to get user info
...
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
...
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
// RefreshAccountToken refreshes token for an OpenAI account
// RefreshAccountToken refreshes token for an OpenAI account
func
(
s
*
OpenAIOAuthService
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
OpenAITokenInfo
,
error
)
{
func
(
s
*
OpenAIOAuthService
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
OpenAITokenInfo
,
error
)
{
if
!
account
.
IsOpenAI
()
{
if
!
account
.
IsOpenAI
()
{
return
nil
,
fmt
.
Errorf
(
"account is not an OpenAI account"
)
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_INVALID_ACCOUNT"
,
"account is not an OpenAI account"
)
}
}
refreshToken
:=
account
.
GetOpenAIRefreshToken
()
refreshToken
:=
account
.
GetOpenAIRefreshToken
()
if
refreshToken
==
""
{
if
refreshToken
==
""
{
return
nil
,
fmt
.
Errorf
(
"no refresh token available"
)
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_NO_REFRESH_TOKEN"
,
"no refresh token available"
)
}
}
var
proxyURL
string
var
proxyURL
string
...
...
backend/internal/service/openai_token_provider.go
View file @
2fe8932c
...
@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
}
// 3. 存入缓存
// 3. 存入缓存
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
refreshFailed
{
if
isStale
&&
latestAccount
!=
nil
{
//
刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
//
版本过时,使用 DB 中的最新 token
ttl
=
time
.
Minute
slog
.
Debug
(
"openai_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
accessToken
=
latestAccount
.
GetOpenAIAccessToken
(
)
}
else
if
expiresAt
!=
nil
{
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
until
:=
time
.
Until
(
*
expiresAt
)
return
""
,
errors
.
New
(
"access_token not found after version check"
)
switch
{
}
case
until
>
openAITokenCacheSkew
:
// 不写入缓存,让下次请求重新处理
ttl
=
until
-
openAITokenCacheSkew
}
else
{
case
until
>
0
:
ttl
:=
30
*
time
.
Minute
ttl
=
until
if
refreshFailed
{
default
:
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
ttl
=
time
.
Minute
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
openAITokenCacheSkew
:
ttl
=
until
-
openAITokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
...
...
backend/internal/service/openai_tool_corrector.go
View file @
2fe8932c
...
@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
...
@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
"executeBash"
:
"bash"
,
"executeBash"
:
"bash"
,
"exec_bash"
:
"bash"
,
"exec_bash"
:
"bash"
,
"execBash"
:
"bash"
,
"execBash"
:
"bash"
,
// Some clients output generic fetch names.
"fetch"
:
"webfetch"
,
"web_fetch"
:
"webfetch"
,
"webFetch"
:
"webfetch"
,
}
}
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
...
@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
...
@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
// 根据工具名称应用特定的参数修正规则
// 根据工具名称应用特定的参数修正规则
switch
toolName
{
switch
toolName
{
case
"bash"
:
case
"bash"
:
// 移除 workdir 参数(OpenCode 不支持)
// OpenCode bash 支持 workdir;有些来源会输出 work_dir。
if
_
,
exists
:=
argsMap
[
"workdir"
];
exists
{
if
_
,
hasWorkdir
:=
argsMap
[
"workdir"
];
!
hasWorkdir
{
delete
(
argsMap
,
"workdir"
)
if
workDir
,
exists
:=
argsMap
[
"work_dir"
];
exists
{
corrected
=
true
argsMap
[
"workdir"
]
=
workDir
log
.
Printf
(
"[CodexToolCorrector] Removed 'workdir' parameter from bash tool"
)
delete
(
argsMap
,
"work_dir"
)
}
corrected
=
true
if
_
,
exists
:=
argsMap
[
"work_dir"
];
exists
{
log
.
Printf
(
"[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool"
)
delete
(
argsMap
,
"work_dir"
)
}
corrected
=
true
}
else
{
log
.
Printf
(
"[CodexToolCorrector] Removed 'work_dir' parameter from bash tool"
)
if
_
,
exists
:=
argsMap
[
"work_dir"
];
exists
{
delete
(
argsMap
,
"work_dir"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool"
)
}
}
}
case
"edit"
:
case
"edit"
:
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
// OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
// 这里可以添加参数名称的映射逻辑
if
_
,
exists
:=
argsMap
[
"filePath"
];
!
exists
{
if
_
,
exists
:=
argsMap
[
"file_path"
];
!
exists
{
if
filePath
,
exists
:=
argsMap
[
"file_path"
];
exists
{
if
path
,
exists
:=
argsMap
[
"path"
];
exists
{
argsMap
[
"filePath"
]
=
filePath
argsMap
[
"file_path"
]
=
path
delete
(
argsMap
,
"file_path"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool"
)
}
else
if
filePath
,
exists
:=
argsMap
[
"path"
];
exists
{
argsMap
[
"filePath"
]
=
filePath
delete
(
argsMap
,
"path"
)
delete
(
argsMap
,
"path"
)
corrected
=
true
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool"
)
log
.
Printf
(
"[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool"
)
}
else
if
filePath
,
exists
:=
argsMap
[
"file"
];
exists
{
argsMap
[
"filePath"
]
=
filePath
delete
(
argsMap
,
"file"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool"
)
}
}
if
_
,
exists
:=
argsMap
[
"oldString"
];
!
exists
{
if
oldString
,
exists
:=
argsMap
[
"old_string"
];
exists
{
argsMap
[
"oldString"
]
=
oldString
delete
(
argsMap
,
"old_string"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool"
)
}
}
if
_
,
exists
:=
argsMap
[
"newString"
];
!
exists
{
if
newString
,
exists
:=
argsMap
[
"new_string"
];
exists
{
argsMap
[
"newString"
]
=
newString
delete
(
argsMap
,
"new_string"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool"
)
}
}
if
_
,
exists
:=
argsMap
[
"replaceAll"
];
!
exists
{
if
replaceAll
,
exists
:=
argsMap
[
"replace_all"
];
exists
{
argsMap
[
"replaceAll"
]
=
replaceAll
delete
(
argsMap
,
"replace_all"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool"
)
}
}
}
}
}
}
...
...
backend/internal/service/openai_tool_corrector_test.go
View file @
2fe8932c
...
@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
...
@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
expected
map
[
string
]
bool
// key: 期待存在的参数, value: true表示应该存在
expected
map
[
string
]
bool
// key: 期待存在的参数, value: true表示应该存在
}{
}{
{
{
name
:
"re
mov
e workdir
from
bash tool"
,
name
:
"re
nam
e work
_
dir
to workdir in
bash tool"
,
input
:
`{
input
:
`{
"tool_calls": [{
"tool_calls": [{
"function": {
"function": {
"name": "bash",
"name": "bash",
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
"arguments": "{\"command\":\"ls\",\"work
_
dir\":\"/tmp\"}"
}
}
}]
}]
}`
,
}`
,
expected
:
map
[
string
]
bool
{
expected
:
map
[
string
]
bool
{
"command"
:
true
,
"command"
:
true
,
"workdir"
:
false
,
"workdir"
:
true
,
"work_dir"
:
false
,
},
},
},
},
{
{
name
:
"rename
path to file_path in edit tool
"
,
name
:
"rename
snake_case edit params to camelCase
"
,
input
:
`{
input
:
`{
"tool_calls": [{
"tool_calls": [{
"function": {
"function": {
...
@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
...
@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
}]
}]
}`
,
}`
,
expected
:
map
[
string
]
bool
{
expected
:
map
[
string
]
bool
{
"file
_p
ath"
:
true
,
"file
P
ath"
:
true
,
"path"
:
false
,
"path"
:
false
,
"old_string"
:
true
,
"oldString"
:
true
,
"new_string"
:
true
,
"old_string"
:
false
,
"newString"
:
true
,
"new_string"
:
false
,
},
},
},
},
}
}
...
...
backend/internal/service/pricing_service.go
View file @
2fe8932c
...
@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
...
@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
func
normalizeModelNameForPricing
(
model
string
)
string
{
func
normalizeModelNameForPricing
(
model
string
)
string
{
// Common Gemini/VertexAI forms:
// Common Gemini/VertexAI forms:
// - models/gemini-2.0-flash-exp
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-
1
.5-pro
// - publishers/google/models/gemini-
2
.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-
1
.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-
2
.5-pro
model
=
strings
.
TrimSpace
(
model
)
model
=
strings
.
TrimSpace
(
model
)
model
=
strings
.
TrimLeft
(
model
,
"/"
)
model
=
strings
.
TrimLeft
(
model
,
"/"
)
model
=
strings
.
TrimPrefix
(
model
,
"models/"
)
model
=
strings
.
TrimPrefix
(
model
,
"models/"
)
...
...
backend/internal/service/ratelimit_service.go
View file @
2fe8932c
...
@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return
false
return
false
}
}
tempMatched
:=
false
// 先尝试临时不可调度规则(401除外)
// 如果匹配成功,直接返回,不执行后续禁用逻辑
if
statusCode
!=
401
{
if
statusCode
!=
401
{
tempMatched
=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
if
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
{
return
true
}
}
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
upstreamMsg
!=
""
{
if
upstreamMsg
!=
""
{
...
@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
}
switch
statusCode
{
switch
statusCode
{
case
400
:
// 只有当错误信息包含 "organization has been disabled" 时才禁用
if
strings
.
Contains
(
strings
.
ToLower
(
upstreamMsg
),
"organization has been disabled"
)
{
msg
:=
"Organization disabled (400): "
+
upstreamMsg
s
.
handleAuthError
(
ctx
,
account
,
msg
)
shouldDisable
=
true
}
// 其他 400 错误(如参数问题)不处理,不禁用账号
case
401
:
case
401
:
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if
account
.
Type
==
AccountTypeOAuth
{
if
account
.
Type
==
AccountTypeOAuth
{
...
@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
}
}
}
if
tempMatched
{
return
true
}
return
shouldDisable
return
shouldDisable
}
}
...
@@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start
:=
geminiDailyWindowStart
(
now
)
start
:=
geminiDailyWindowStart
(
now
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
if
!
ok
{
if
!
ok
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
return
true
,
err
return
true
,
err
}
}
...
@@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if
limit
>
0
{
if
limit
>
0
{
start
:=
now
.
Truncate
(
time
.
Minute
)
start
:=
now
.
Truncate
(
time
.
Minute
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
return
true
,
err
return
true
,
err
}
}
...
@@ -334,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
...
@@ -334,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
// handle429 处理429限流错误
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
// 解析响应头获取重置时间,标记账号为限流状态
func
(
s
*
RateLimitService
)
handle429
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
,
responseBody
[]
byte
)
{
func
(
s
*
RateLimitService
)
handle429
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
,
responseBody
[]
byte
)
{
// 解析重置时间戳
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
if
account
.
Platform
==
PlatformOpenAI
{
if
resetAt
:=
s
.
calculateOpenAI429ResetTime
(
headers
);
resetAt
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
*
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"openai_account_rate_limited"
,
"account_id"
,
account
.
ID
,
"reset_at"
,
*
resetAt
)
return
}
}
// 2. 尝试从响应头解析重置时间(Anthropic)
resetTimestamp
:=
headers
.
Get
(
"anthropic-ratelimit-unified-reset"
)
resetTimestamp
:=
headers
.
Get
(
"anthropic-ratelimit-unified-reset"
)
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
if
resetTimestamp
==
""
{
if
resetTimestamp
==
""
{
switch
account
.
Platform
{
case
PlatformOpenAI
:
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if
resetAt
:=
parseOpenAIRateLimitResetTime
(
responseBody
);
resetAt
!=
nil
{
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"account_rate_limited"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"reset_at"
,
resetTime
,
"reset_in"
,
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
return
}
case
PlatformGemini
,
PlatformAntigravity
:
// 尝试解析 Gemini 格式(用于其他平台)
if
resetAt
:=
ParseGeminiRateLimitResetTime
(
responseBody
);
resetAt
!=
nil
{
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"account_rate_limited"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"reset_at"
,
resetTime
,
"reset_in"
,
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
return
}
}
// 没有重置时间,使用默认5分钟
// 没有重置时间,使用默认5分钟
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
...
@@ -347,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -347,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
}
return
return
}
}
slog
.
Warn
(
"rate_limit_no_reset_time"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"using_default"
,
"5m"
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
...
@@ -410,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
...
@@ -410,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
return
strings
.
Contains
(
msg
,
"sonnet"
)
return
strings
.
Contains
(
msg
,
"sonnet"
)
}
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func
(
s
*
RateLimitService
)
calculateOpenAI429ResetTime
(
headers
http
.
Header
)
*
time
.
Time
{
snapshot
:=
ParseCodexRateLimitHeaders
(
headers
)
if
snapshot
==
nil
{
return
nil
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
return
nil
}
now
:=
time
.
Now
()
// 判断哪个限制被触发(used_percent >= 100)
is7dExhausted
:=
normalized
.
Used7dPercent
!=
nil
&&
*
normalized
.
Used7dPercent
>=
100
is5hExhausted
:=
normalized
.
Used5hPercent
!=
nil
&&
*
normalized
.
Used5hPercent
>=
100
// 优先使用被触发限制的重置时间
if
is7dExhausted
&&
normalized
.
Reset7dSeconds
!=
nil
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
*
normalized
.
Reset7dSeconds
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_7d_limit_exhausted"
,
"reset_after_seconds"
,
*
normalized
.
Reset7dSeconds
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
if
is5hExhausted
&&
normalized
.
Reset5hSeconds
!=
nil
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
*
normalized
.
Reset5hSeconds
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_5h_limit_exhausted"
,
"reset_after_seconds"
,
*
normalized
.
Reset5hSeconds
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
// 都未达到100%但收到429,使用较长的重置时间
var
maxResetSecs
int
if
normalized
.
Reset7dSeconds
!=
nil
&&
*
normalized
.
Reset7dSeconds
>
maxResetSecs
{
maxResetSecs
=
*
normalized
.
Reset7dSeconds
}
if
normalized
.
Reset5hSeconds
!=
nil
&&
*
normalized
.
Reset5hSeconds
>
maxResetSecs
{
maxResetSecs
=
*
normalized
.
Reset5hSeconds
}
if
maxResetSecs
>
0
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
maxResetSecs
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_using_max_reset"
,
"max_reset_seconds"
,
maxResetSecs
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
return
nil
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
// {
// "error": {
// "message": "The usage limit has been reached",
// "type": "usage_limit_reached",
// "resets_at": 1769404154,
// "resets_in_seconds": 133107
// }
// }
func
parseOpenAIRateLimitResetTime
(
body
[]
byte
)
*
int64
{
var
parsed
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
parsed
);
err
!=
nil
{
return
nil
}
errObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
errType
,
_
:=
errObj
[
"type"
]
.
(
string
)
if
errType
!=
"usage_limit_reached"
&&
errType
!=
"rate_limit_exceeded"
{
return
nil
}
// 优先使用 resets_at(Unix 时间戳)
if
resetsAt
,
ok
:=
errObj
[
"resets_at"
]
.
(
float64
);
ok
{
ts
:=
int64
(
resetsAt
)
return
&
ts
}
if
resetsAt
,
ok
:=
errObj
[
"resets_at"
]
.
(
string
);
ok
{
if
ts
,
err
:=
strconv
.
ParseInt
(
resetsAt
,
10
,
64
);
err
==
nil
{
return
&
ts
}
}
// 如果没有 resets_at,尝试使用 resets_in_seconds
if
resetsInSeconds
,
ok
:=
errObj
[
"resets_in_seconds"
]
.
(
float64
);
ok
{
ts
:=
time
.
Now
()
.
Unix
()
+
int64
(
resetsInSeconds
)
return
&
ts
}
if
resetsInSeconds
,
ok
:=
errObj
[
"resets_in_seconds"
]
.
(
string
);
ok
{
if
sec
,
err
:=
strconv
.
ParseInt
(
resetsInSeconds
,
10
,
64
);
err
==
nil
{
ts
:=
time
.
Now
()
.
Unix
()
+
sec
return
&
ts
}
}
return
nil
}
// handle529 处理529过载错误
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
// 根据配置设置过载冷却时间
func
(
s
*
RateLimitService
)
handle529
(
ctx
context
.
Context
,
account
*
Account
)
{
func
(
s
*
RateLimitService
)
handle529
(
ctx
context
.
Context
,
account
*
Account
)
{
...
...
backend/internal/service/ratelimit_service_openai_test.go
0 → 100644
View file @
2fe8932c
package
service
import
(
"net/http"
"testing"
"time"
)
func
TestCalculateOpenAI429ResetTime_7dExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
// ~4.5 days
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
// ~4.8 hours
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 384607 seconds from now
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_5hExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 5h limit is exhausted (100% used)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 3600 seconds from now
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Neither limit at 100%, should use the longer reset time
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"80"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"100000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"90"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"5000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should use the max (100000 seconds from 7d window)
expectedDuration
:=
100000
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NoCodexHeaders
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// No codex headers at all
headers
:=
http
.
Header
{}
headers
.
Set
(
"content-type"
,
"application/json"
)
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil resetAt when no codex headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_ReversedWindowOrder
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// This is 5h
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"300"
)
// 5 hours - smaller!
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"10080"
)
// 7 days - larger!
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestNormalizedCodexLimits
(
t
*
testing
.
T
)
{
// Test the Normalize() method directly
pUsed
:=
100.0
pReset
:=
384607
pWindow
:=
10080
sUsed
:=
3.0
sReset
:=
17369
sWindow
:=
300
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
PrimaryWindowMinutes
:
&
pWindow
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
SecondaryWindowMinutes
:
&
sWindow
,
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Primary has larger window (10080 > 300), so primary should be 7d
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
384607
{
t
.
Errorf
(
"expected Reset7dSeconds=384607, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
3.0
{
t
.
Errorf
(
"expected Used5hPercent=3, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
17369
{
t
.
Errorf
(
"expected Reset5hSeconds=17369, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlyPrimaryData
(
t
*
testing
.
T
)
{
// Test when only primary has data, no window_minutes
pUsed
:=
80.0
pReset
:=
50000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
// No window_minutes, no secondary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
80.0
{
t
.
Errorf
(
"expected Used7dPercent=80, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
50000
{
t
.
Errorf
(
"expected Reset7dSeconds=50000, got %v"
,
normalized
.
Reset7dSeconds
)
}
// Secondary (5h) should be nil
if
normalized
.
Used5hPercent
!=
nil
{
t
.
Errorf
(
"expected Used5hPercent=nil, got %v"
,
*
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
!=
nil
{
t
.
Errorf
(
"expected Reset5hSeconds=nil, got %v"
,
*
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlySecondaryData
(
t
*
testing
.
T
)
{
// Test when only secondary has data, no window_minutes
sUsed
:=
60.0
sReset
:=
3000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes, no primary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
60.0
{
t
.
Errorf
(
"expected Used5hPercent=60, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
3000
{
t
.
Errorf
(
"expected Reset5hSeconds=3000, got %v"
,
normalized
.
Reset5hSeconds
)
}
// Primary (7d) should be nil
if
normalized
.
Used7dPercent
!=
nil
{
t
.
Errorf
(
"expected Used7dPercent=nil, got %v"
,
*
normalized
.
Used7dPercent
)
}
}
func
TestNormalizedCodexLimits_BothDataNoWindowMinutes
(
t
*
testing
.
T
)
{
// Test when both have data but no window_minutes
pUsed
:=
100.0
pReset
:=
400000
sUsed
:=
50.0
sReset
:=
10000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
400000
{
t
.
Errorf
(
"expected Reset7dSeconds=400000, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
50.0
{
t
.
Errorf
(
"expected Used5hPercent=50, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
10000
{
t
.
Errorf
(
"expected Reset5hSeconds=10000, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestHandle429_AnthropicPlatformUnaffected
(
t
*
testing
.
T
)
{
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc
:=
&
RateLimitService
{}
// Simulate Anthropic 429 headers
headers
:=
http
.
Header
{}
headers
.
Set
(
"anthropic-ratelimit-unified-reset"
,
"1737820800"
)
// A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there are no x-codex-* headers
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil for Anthropic headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_UserProvidedScenario
(
t
*
testing
.
T
)
{
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc
:=
&
RateLimitService
{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days = 10080 minutes
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours = 300 minutes
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt for user scenario"
)
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration
:=
resetAt
.
Sub
(
before
)
actualDays
:=
duration
.
Hours
()
/
24.0
// 384607 / 86400 = ~4.45 days
if
actualDays
<
4.4
||
actualDays
>
4.5
{
t
.
Errorf
(
"expected ~4.45 days, got %.2f days"
,
actualDays
)
}
t
.
Logf
(
"User scenario: reset_at=%v, duration=%.2f days"
,
resetAt
,
actualDays
)
}
func
TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset
(
t
*
testing
.
T
)
{
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc
:=
&
RateLimitService
{}
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// No reset_after_seconds!
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there's no reset time available
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil when no reset_after_seconds, got %v"
,
resetAt
)
}
}
backend/internal/service/session_limit_cache.go
View file @
2fe8932c
...
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
...
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count,查询失败的账号不在 map 中
// 返回 map[accountID]count,查询失败的账号不在 map 中
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
idleTimeouts
map
[
int64
]
time
.
Duration
)
(
map
[
int64
]
int
,
error
)
// IsSessionActive 检查特定会话是否活跃(未过期)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
...
...
backend/internal/service/setting_service.go
View file @
2fe8932c
...
@@ -60,6 +60,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
...
@@ -60,6 +60,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys
:=
[]
string
{
keys
:=
[]
string
{
SettingKeyRegistrationEnabled
,
SettingKeyRegistrationEnabled
,
SettingKeyEmailVerifyEnabled
,
SettingKeyEmailVerifyEnabled
,
SettingKeyPromoCodeEnabled
,
SettingKeyPasswordResetEnabled
,
SettingKeyTotpEnabled
,
SettingKeyTurnstileEnabled
,
SettingKeyTurnstileEnabled
,
SettingKeyTurnstileSiteKey
,
SettingKeyTurnstileSiteKey
,
SettingKeySiteName
,
SettingKeySiteName
,
...
@@ -69,6 +72,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
...
@@ -69,6 +72,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyContactInfo
,
SettingKeyContactInfo
,
SettingKeyDocURL
,
SettingKeyDocURL
,
SettingKeyHomeContent
,
SettingKeyHomeContent
,
SettingKeyHideCcsImportButton
,
SettingKeyPurchaseSubscriptionEnabled
,
SettingKeyPurchaseSubscriptionURL
,
SettingKeyLinuxDoConnectEnabled
,
SettingKeyLinuxDoConnectEnabled
,
}
}
...
@@ -84,19 +90,29 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
...
@@ -84,19 +90,29 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
linuxDoEnabled
=
s
.
cfg
!=
nil
&&
s
.
cfg
.
LinuxDo
.
Enabled
linuxDoEnabled
=
s
.
cfg
!=
nil
&&
s
.
cfg
.
LinuxDo
.
Enabled
}
}
// Password reset requires email verification to be enabled
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
passwordResetEnabled
:=
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
return
&
PublicSettings
{
return
&
PublicSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
emailVerifyEnabled
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
PasswordResetEnabled
:
passwordResetEnabled
,
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
SiteLogo
:
settings
[
SettingKeySiteLogo
],
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
APIBaseURL
:
settings
[
SettingKeyAPIBaseURL
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
ContactInfo
:
settings
[
SettingKeyContactInfo
],
SiteLogo
:
settings
[
SettingKeySiteLogo
],
DocURL
:
settings
[
SettingKeyDocURL
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
HomeContent
:
settings
[
SettingKeyHomeContent
],
APIBaseURL
:
settings
[
SettingKeyAPIBaseURL
],
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
},
nil
},
nil
}
}
...
@@ -121,33 +137,45 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
...
@@ -121,33 +137,45 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// Return a struct that matches the frontend's expected format
// Return a struct that matches the frontend's expected format
return
&
struct
{
return
&
struct
{
RegistrationEnabled
bool
`json:"registration_enabled"`
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key,omitempty"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
SiteName
string
`json:"site_name"`
TotpEnabled
bool
`json:"totp_enabled"`
SiteLogo
string
`json:"site_logo,omitempty"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
SiteSubtitle
string
`json:"site_subtitle,omitempty"`
TurnstileSiteKey
string
`json:"turnstile_site_key,omitempty"`
APIBaseURL
string
`json:"api_base_url,omitempty"`
SiteName
string
`json:"site_name"`
ContactInfo
string
`json:"contact_info,omitempty"`
SiteLogo
string
`json:"site_logo,omitempty"`
DocURL
string
`json:"doc_url,omitempty"`
SiteSubtitle
string
`json:"site_subtitle,omitempty"`
HomeContent
string
`json:"home_content,omitempty"`
APIBaseURL
string
`json:"api_base_url,omitempty"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
ContactInfo
string
`json:"contact_info,omitempty"`
Version
string
`json:"version,omitempty"`
DocURL
string
`json:"doc_url,omitempty"`
HomeContent
string
`json:"home_content,omitempty"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url,omitempty"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version,omitempty"`
}{
}{
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
PromoCodeEnabled
:
settings
.
PromoCodeEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
PasswordResetEnabled
:
settings
.
PasswordResetEnabled
,
SiteName
:
settings
.
SiteName
,
TotpEnabled
:
settings
.
TotpEnabled
,
SiteLogo
:
settings
.
SiteLogo
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
APIBaseURL
:
settings
.
APIBaseURL
,
SiteName
:
settings
.
SiteName
,
ContactInfo
:
settings
.
ContactInfo
,
SiteLogo
:
settings
.
SiteLogo
,
DocURL
:
settings
.
DocURL
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
HomeContent
:
settings
.
HomeContent
,
APIBaseURL
:
settings
.
APIBaseURL
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
ContactInfo
:
settings
.
ContactInfo
,
Version
:
s
.
version
,
DocURL
:
settings
.
DocURL
,
HomeContent
:
settings
.
HomeContent
,
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
PurchaseSubscriptionEnabled
:
settings
.
PurchaseSubscriptionEnabled
,
PurchaseSubscriptionURL
:
settings
.
PurchaseSubscriptionURL
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
Version
:
s
.
version
,
},
nil
},
nil
}
}
...
@@ -158,6 +186,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
...
@@ -158,6 +186,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 注册设置
// 注册设置
updates
[
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyPromoCodeEnabled
]
=
strconv
.
FormatBool
(
settings
.
PromoCodeEnabled
)
updates
[
SettingKeyPasswordResetEnabled
]
=
strconv
.
FormatBool
(
settings
.
PasswordResetEnabled
)
updates
[
SettingKeyTotpEnabled
]
=
strconv
.
FormatBool
(
settings
.
TotpEnabled
)
// 邮件服务设置(只有非空才更新密码)
// 邮件服务设置(只有非空才更新密码)
updates
[
SettingKeySMTPHost
]
=
settings
.
SMTPHost
updates
[
SettingKeySMTPHost
]
=
settings
.
SMTPHost
...
@@ -193,6 +224,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
...
@@ -193,6 +224,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyDocURL
]
=
settings
.
DocURL
updates
[
SettingKeyDocURL
]
=
settings
.
DocURL
updates
[
SettingKeyHomeContent
]
=
settings
.
HomeContent
updates
[
SettingKeyHomeContent
]
=
settings
.
HomeContent
updates
[
SettingKeyHideCcsImportButton
]
=
strconv
.
FormatBool
(
settings
.
HideCcsImportButton
)
updates
[
SettingKeyPurchaseSubscriptionEnabled
]
=
strconv
.
FormatBool
(
settings
.
PurchaseSubscriptionEnabled
)
updates
[
SettingKeyPurchaseSubscriptionURL
]
=
strings
.
TrimSpace
(
settings
.
PurchaseSubscriptionURL
)
// 默认配置
// 默认配置
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
...
@@ -243,6 +277,44 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
...
@@ -243,6 +277,44 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
return
value
==
"true"
return
value
==
"true"
}
}
// IsPromoCodeEnabled 检查是否启用优惠码功能
func
(
s
*
SettingService
)
IsPromoCodeEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyPromoCodeEnabled
)
if
err
!=
nil
{
return
true
// 默认启用
}
return
value
!=
"false"
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func
(
s
*
SettingService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
// Password reset requires email verification to be enabled
if
!
s
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyPasswordResetEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
func
(
s
*
SettingService
)
IsTotpEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyTotpEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
func
(
s
*
SettingService
)
IsTotpEncryptionKeyConfigured
()
bool
{
return
s
.
cfg
.
Totp
.
EncryptionKeyConfigured
}
// GetSiteName 获取网站名称
// GetSiteName 获取网站名称
func
(
s
*
SettingService
)
GetSiteName
(
ctx
context
.
Context
)
string
{
func
(
s
*
SettingService
)
GetSiteName
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
...
@@ -290,14 +362,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
...
@@ -290,14 +362,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
// 初始化默认设置
defaults
:=
map
[
string
]
string
{
defaults
:=
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"false"
,
SettingKeyEmailVerifyEnabled
:
"false"
,
SettingKeySiteName
:
"Sub2API"
,
SettingKeyPromoCodeEnabled
:
"true"
,
// 默认启用优惠码功能
SettingKeySiteLogo
:
""
,
SettingKeySiteName
:
"Sub2API"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeySiteLogo
:
""
,
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeySMTPPort
:
"587"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeySMTPUseTLS
:
"false"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeySMTPPort
:
"587"
,
SettingKeySMTPUseTLS
:
"false"
,
// Model fallback defaults
// Model fallback defaults
SettingKeyEnableModelFallback
:
"false"
,
SettingKeyEnableModelFallback
:
"false"
,
SettingKeyFallbackModelAnthropic
:
"claude-3-5-sonnet-20241022"
,
SettingKeyFallbackModelAnthropic
:
"claude-3-5-sonnet-20241022"
,
...
@@ -320,9 +395,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
...
@@ -320,9 +395,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
// parseSettings 解析设置到结构体
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
result
:=
&
SystemSettings
{
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
emailVerifyEnabled
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PasswordResetEnabled
:
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
,
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
...
@@ -339,6 +418,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
...
@@ -339,6 +418,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
ContactInfo
:
settings
[
SettingKeyContactInfo
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
DocURL
:
settings
[
SettingKeyDocURL
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
}
}
// 解析整数类型
// 解析整数类型
...
...
backend/internal/service/settings_view.go
View file @
2fe8932c
package
service
package
service
type
SystemSettings
struct
{
type
SystemSettings
struct
{
RegistrationEnabled
bool
RegistrationEnabled
bool
EmailVerifyEnabled
bool
EmailVerifyEnabled
bool
PromoCodeEnabled
bool
PasswordResetEnabled
bool
TotpEnabled
bool
// TOTP 双因素认证
SMTPHost
string
SMTPHost
string
SMTPPort
int
SMTPPort
int
...
@@ -25,13 +28,16 @@ type SystemSettings struct {
...
@@ -25,13 +28,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured
bool
LinuxDoConnectClientSecretConfigured
bool
LinuxDoConnectRedirectURL
string
LinuxDoConnectRedirectURL
string
SiteName
string
SiteName
string
SiteLogo
string
SiteLogo
string
SiteSubtitle
string
SiteSubtitle
string
APIBaseURL
string
APIBaseURL
string
ContactInfo
string
ContactInfo
string
DocURL
string
DocURL
string
HomeContent
string
HomeContent
string
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
DefaultConcurrency
int
DefaultConcurrency
int
DefaultBalance
float64
DefaultBalance
float64
...
@@ -55,17 +61,25 @@ type SystemSettings struct {
...
@@ -55,17 +61,25 @@ type SystemSettings struct {
}
}
type
PublicSettings
struct
{
type
PublicSettings
struct
{
RegistrationEnabled
bool
RegistrationEnabled
bool
EmailVerifyEnabled
bool
EmailVerifyEnabled
bool
TurnstileEnabled
bool
PromoCodeEnabled
bool
TurnstileSiteKey
string
PasswordResetEnabled
bool
SiteName
string
TotpEnabled
bool
// TOTP 双因素认证
SiteLogo
string
TurnstileEnabled
bool
SiteSubtitle
string
TurnstileSiteKey
string
APIBaseURL
string
SiteName
string
ContactInfo
string
SiteLogo
string
DocURL
string
SiteSubtitle
string
HomeContent
string
APIBaseURL
string
ContactInfo
string
DocURL
string
HomeContent
string
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
LinuxDoOAuthEnabled
bool
LinuxDoOAuthEnabled
bool
Version
string
Version
string
}
}
...
...
backend/internal/service/sticky_session_test.go
0 → 100644
View file @
2fe8932c
//go:build unit
// Package service 提供 API 网关核心服务。
// 本文件包含 shouldClearStickySession 函数的单元测试,
// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
//
// This file contains unit tests for the shouldClearStickySession function,
// verifying correct sticky session clearing behavior under various account states.
package
service
import
(
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
// 验证在以下情况下是否正确判断需要清理粘性会话:
// - nil 账号:不清理(返回 false)
// - 状态为错误或禁用:清理
// - 不可调度:清理
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable.
func
TestShouldClearStickySession
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
future
:=
now
.
Add
(
1
*
time
.
Hour
)
past
:=
now
.
Add
(
-
1
*
time
.
Hour
)
tests
:=
[]
struct
{
name
string
account
*
Account
want
bool
}{
{
name
:
"nil account"
,
account
:
nil
,
want
:
false
},
{
name
:
"status error"
,
account
:
&
Account
{
Status
:
StatusError
,
Schedulable
:
true
},
want
:
true
},
{
name
:
"status disabled"
,
account
:
&
Account
{
Status
:
StatusDisabled
,
Schedulable
:
true
},
want
:
true
},
{
name
:
"schedulable false"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
false
},
want
:
true
},
{
name
:
"temp unschedulable"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
,
TempUnschedulableUntil
:
&
future
},
want
:
true
},
{
name
:
"temp unschedulable expired"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
,
TempUnschedulableUntil
:
&
past
},
want
:
false
},
{
name
:
"active schedulable"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
},
want
:
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
tt
.
want
,
shouldClearStickySession
(
tt
.
account
))
})
}
}
backend/internal/service/subscription_expiry_service.go
0 → 100644
View file @
2fe8932c
package
service
import
(
"context"
"log"
"sync"
"time"
)
// SubscriptionExpiryService periodically updates expired subscription status.
type
SubscriptionExpiryService
struct
{
userSubRepo
UserSubscriptionRepository
interval
time
.
Duration
stopCh
chan
struct
{}
stopOnce
sync
.
Once
wg
sync
.
WaitGroup
}
func
NewSubscriptionExpiryService
(
userSubRepo
UserSubscriptionRepository
,
interval
time
.
Duration
)
*
SubscriptionExpiryService
{
return
&
SubscriptionExpiryService
{
userSubRepo
:
userSubRepo
,
interval
:
interval
,
stopCh
:
make
(
chan
struct
{}),
}
}
func
(
s
*
SubscriptionExpiryService
)
Start
()
{
if
s
==
nil
||
s
.
userSubRepo
==
nil
||
s
.
interval
<=
0
{
return
}
s
.
wg
.
Add
(
1
)
go
func
()
{
defer
s
.
wg
.
Done
()
ticker
:=
time
.
NewTicker
(
s
.
interval
)
defer
ticker
.
Stop
()
s
.
runOnce
()
for
{
select
{
case
<-
ticker
.
C
:
s
.
runOnce
()
case
<-
s
.
stopCh
:
return
}
}
}()
}
func
(
s
*
SubscriptionExpiryService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
s
.
wg
.
Wait
()
}
func
(
s
*
SubscriptionExpiryService
)
runOnce
()
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
updated
,
err
:=
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
if
err
!=
nil
{
log
.
Printf
(
"[SubscriptionExpiry] Update expired subscriptions failed: %v"
,
err
)
return
}
if
updated
>
0
{
log
.
Printf
(
"[SubscriptionExpiry] Updated %d expired subscriptions"
,
updated
)
}
}
backend/internal/service/subscription_service.go
View file @
2fe8932c
...
@@ -27,6 +27,7 @@ var (
...
@@ -27,6 +27,7 @@ var (
ErrWeeklyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"WEEKLY_LIMIT_EXCEEDED"
,
"weekly usage limit exceeded"
)
ErrWeeklyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"WEEKLY_LIMIT_EXCEEDED"
,
"weekly usage limit exceeded"
)
ErrMonthlyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"MONTHLY_LIMIT_EXCEEDED"
,
"monthly usage limit exceeded"
)
ErrMonthlyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"MONTHLY_LIMIT_EXCEEDED"
,
"monthly usage limit exceeded"
)
ErrSubscriptionNilInput
=
infraerrors
.
BadRequest
(
"SUBSCRIPTION_NIL_INPUT"
,
"subscription input cannot be nil"
)
ErrSubscriptionNilInput
=
infraerrors
.
BadRequest
(
"SUBSCRIPTION_NIL_INPUT"
,
"subscription input cannot be nil"
)
ErrAdjustWouldExpire
=
infraerrors
.
BadRequest
(
"ADJUST_WOULD_EXPIRE"
,
"adjustment would result in expired subscription (remaining days must be > 0)"
)
)
)
// SubscriptionService 订阅服务
// SubscriptionService 订阅服务
...
@@ -308,24 +309,48 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
...
@@ -308,24 +309,48 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
return
nil
return
nil
}
}
// ExtendSubscription
延长订阅
// ExtendSubscription
调整订阅时长(正数延长,负数缩短)
func
(
s
*
SubscriptionService
)
ExtendSubscription
(
ctx
context
.
Context
,
subscriptionID
int64
,
days
int
)
(
*
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
ExtendSubscription
(
ctx
context
.
Context
,
subscriptionID
int64
,
days
int
)
(
*
UserSubscription
,
error
)
{
sub
,
err
:=
s
.
userSubRepo
.
GetByID
(
ctx
,
subscriptionID
)
sub
,
err
:=
s
.
userSubRepo
.
GetByID
(
ctx
,
subscriptionID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
ErrSubscriptionNotFound
return
nil
,
ErrSubscriptionNotFound
}
}
// 限制
延长天数
// 限制
调整天数范围
if
days
>
MaxValidityDays
{
if
days
>
MaxValidityDays
{
days
=
MaxValidityDays
days
=
MaxValidityDays
}
}
if
days
<
-
MaxValidityDays
{
days
=
-
MaxValidityDays
}
now
:=
time
.
Now
()
isExpired
:=
!
sub
.
ExpiresAt
.
After
(
now
)
// 如果订阅已过期,不允许负向调整
if
isExpired
&&
days
<
0
{
return
nil
,
infraerrors
.
BadRequest
(
"CANNOT_SHORTEN_EXPIRED"
,
"cannot shorten an expired subscription"
)
}
// 计算新的过期时间
// 计算新的过期时间
newExpiresAt
:=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
var
newExpiresAt
time
.
Time
if
isExpired
{
// 已过期:从当前时间开始增加天数
newExpiresAt
=
now
.
AddDate
(
0
,
0
,
days
)
}
else
{
// 未过期:从原过期时间增加/减少天数
newExpiresAt
=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
}
if
newExpiresAt
.
After
(
MaxExpiresAt
)
{
if
newExpiresAt
.
After
(
MaxExpiresAt
)
{
newExpiresAt
=
MaxExpiresAt
newExpiresAt
=
MaxExpiresAt
}
}
// 检查新的过期时间必须大于当前时间
if
!
newExpiresAt
.
After
(
now
)
{
return
nil
,
ErrAdjustWouldExpire
}
if
err
:=
s
.
userSubRepo
.
ExtendExpiry
(
ctx
,
subscriptionID
,
newExpiresAt
);
err
!=
nil
{
if
err
:=
s
.
userSubRepo
.
ExtendExpiry
(
ctx
,
subscriptionID
,
newExpiresAt
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -371,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
...
@@ -371,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
return
nil
,
err
return
nil
,
err
}
}
normalizeExpiredWindows
(
subs
)
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
nil
return
subs
,
nil
}
}
...
@@ -392,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
...
@@ -392,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
normalizeExpiredWindows
(
subs
)
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
return
subs
,
pag
,
nil
}
}
// List 获取所有订阅(分页,支持筛选)
// List 获取所有订阅(分页,支持筛选
和排序
)
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
)
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
,
sortBy
,
sortOrder
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
normalizeExpiredWindows
(
subs
)
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
return
subs
,
pag
,
nil
}
}
...
@@ -429,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
...
@@ -429,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
}
}
}
}
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
func
normalizeSubscriptionStatus
(
subs
[]
UserSubscription
)
{
now
:=
time
.
Now
()
for
i
:=
range
subs
{
sub
:=
&
subs
[
i
]
if
sub
.
Status
==
SubscriptionStatusActive
&&
!
sub
.
ExpiresAt
.
After
(
now
)
{
sub
.
Status
=
SubscriptionStatusExpired
}
}
}
// startOfDay 返回给定时间所在日期的零点(保持原时区)
// startOfDay 返回给定时间所在日期的零点(保持原时区)
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
...
@@ -647,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
...
@@ -647,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
return
progresses
,
nil
return
progresses
,
nil
}
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func
(
s
*
SubscriptionService
)
UpdateExpiredSubscriptions
(
ctx
context
.
Context
)
(
int64
,
error
)
{
return
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
}
// ValidateSubscription 验证订阅是否有效
// ValidateSubscription 验证订阅是否有效
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment