Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
bb664d9b
Commit
bb664d9b
authored
Feb 28, 2026
by
yangjianbo
Browse files
feat(sync): full code sync from release
parent
bfc7b339
Changes
244
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
244 of 244+
files are displayed.
Plain diff
Email patch
backend/internal/pkg/geminicli/constants.go
View file @
bb664d9b
...
@@ -39,7 +39,7 @@ const (
...
@@ -39,7 +39,7 @@ const (
// They enable the "login without creating your own OAuth client" experience, but Google may
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID
=
"681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientID
=
"681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret
=
"GOCSPX-
4uHgMPm-1o7Sk-geV6Cu5clXFsxl
"
GeminiCLIOAuthClientSecret
=
"GOCSPX-
your-client-secret
"
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv
=
"GEMINI_CLI_OAUTH_CLIENT_SECRET"
GeminiCLIOAuthClientSecretEnv
=
"GEMINI_CLI_OAUTH_CLIENT_SECRET"
...
...
backend/internal/pkg/httpclient/pool.go
View file @
bb664d9b
...
@@ -32,6 +32,7 @@ const (
...
@@ -32,6 +32,7 @@ const (
defaultMaxIdleConns
=
100
// 最大空闲连接数
defaultMaxIdleConns
=
100
// 最大空闲连接数
defaultMaxIdleConnsPerHost
=
10
// 每个主机最大空闲连接数
defaultMaxIdleConnsPerHost
=
10
// 每个主机最大空闲连接数
defaultIdleConnTimeout
=
90
*
time
.
Second
// 空闲连接超时时间(建议小于上游 LB 超时)
defaultIdleConnTimeout
=
90
*
time
.
Second
// 空闲连接超时时间(建议小于上游 LB 超时)
validatedHostTTL
=
30
*
time
.
Second
// DNS Rebinding 校验缓存 TTL
)
)
// Options 定义共享 HTTP 客户端的构建参数
// Options 定义共享 HTTP 客户端的构建参数
...
@@ -53,6 +54,9 @@ type Options struct {
...
@@ -53,6 +54,9 @@ type Options struct {
// sharedClients 存储按配置参数缓存的 http.Client 实例
// sharedClients 存储按配置参数缓存的 http.Client 实例
var
sharedClients
sync
.
Map
var
sharedClients
sync
.
Map
// 允许测试替换校验函数,生产默认指向真实实现。
var
validateResolvedIP
=
urlvalidator
.
ValidateResolvedIP
// GetClient 返回共享的 HTTP 客户端实例
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
...
@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) {
...
@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) {
var
rt
http
.
RoundTripper
=
transport
var
rt
http
.
RoundTripper
=
transport
if
opts
.
ValidateResolvedIP
&&
!
opts
.
AllowPrivateHosts
{
if
opts
.
ValidateResolvedIP
&&
!
opts
.
AllowPrivateHosts
{
rt
=
&
v
alidatedTransport
{
base
:
transport
}
rt
=
newV
alidatedTransport
(
transport
)
}
}
return
&
http
.
Client
{
return
&
http
.
Client
{
Transport
:
rt
,
Transport
:
rt
,
...
@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string {
...
@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string {
}
}
type
validatedTransport
struct
{
type
validatedTransport
struct
{
base
http
.
RoundTripper
base
http
.
RoundTripper
validatedHosts
sync
.
Map
// map[string]time.Time, value 为过期时间
now
func
()
time
.
Time
}
func
newValidatedTransport
(
base
http
.
RoundTripper
)
*
validatedTransport
{
return
&
validatedTransport
{
base
:
base
,
now
:
time
.
Now
,
}
}
func
(
t
*
validatedTransport
)
isValidatedHost
(
host
string
,
now
time
.
Time
)
bool
{
if
t
==
nil
{
return
false
}
raw
,
ok
:=
t
.
validatedHosts
.
Load
(
host
)
if
!
ok
{
return
false
}
expireAt
,
ok
:=
raw
.
(
time
.
Time
)
if
!
ok
{
t
.
validatedHosts
.
Delete
(
host
)
return
false
}
if
now
.
Before
(
expireAt
)
{
return
true
}
t
.
validatedHosts
.
Delete
(
host
)
return
false
}
}
func
(
t
*
validatedTransport
)
RoundTrip
(
req
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
func
(
t
*
validatedTransport
)
RoundTrip
(
req
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
if
req
!=
nil
&&
req
.
URL
!=
nil
{
if
req
!=
nil
&&
req
.
URL
!=
nil
{
host
:=
strings
.
TrimSpace
(
req
.
URL
.
Hostname
())
host
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
req
.
URL
.
Hostname
())
)
if
host
!=
""
{
if
host
!=
""
{
if
err
:=
urlvalidator
.
ValidateResolvedIP
(
host
);
err
!=
nil
{
now
:=
time
.
Now
()
return
nil
,
err
if
t
!=
nil
&&
t
.
now
!=
nil
{
now
=
t
.
now
()
}
if
!
t
.
isValidatedHost
(
host
,
now
)
{
if
err
:=
validateResolvedIP
(
host
);
err
!=
nil
{
return
nil
,
err
}
t
.
validatedHosts
.
Store
(
host
,
now
.
Add
(
validatedHostTTL
))
}
}
}
}
}
}
if
t
==
nil
||
t
.
base
==
nil
{
return
nil
,
fmt
.
Errorf
(
"validated transport base is nil"
)
}
return
t
.
base
.
RoundTrip
(
req
)
return
t
.
base
.
RoundTrip
(
req
)
}
}
backend/internal/pkg/httpclient/pool_test.go
0 → 100644
View file @
bb664d9b
package
httpclient
import
(
"errors"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type
roundTripFunc
func
(
*
http
.
Request
)
(
*
http
.
Response
,
error
)
func
(
f
roundTripFunc
)
RoundTrip
(
req
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
return
f
(
req
)
}
func
TestValidatedTransport_CacheHostValidation
(
t
*
testing
.
T
)
{
originalValidate
:=
validateResolvedIP
defer
func
()
{
validateResolvedIP
=
originalValidate
}()
var
validateCalls
int32
validateResolvedIP
=
func
(
host
string
)
error
{
atomic
.
AddInt32
(
&
validateCalls
,
1
)
require
.
Equal
(
t
,
"api.openai.com"
,
host
)
return
nil
}
var
baseCalls
int32
base
:=
roundTripFunc
(
func
(
_
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
atomic
.
AddInt32
(
&
baseCalls
,
1
)
return
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{}`
)),
Header
:
make
(
http
.
Header
),
},
nil
})
now
:=
time
.
Unix
(
1730000000
,
0
)
transport
:=
newValidatedTransport
(
base
)
transport
.
now
=
func
()
time
.
Time
{
return
now
}
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
"https://api.openai.com/v1/responses"
,
nil
)
require
.
NoError
(
t
,
err
)
_
,
err
=
transport
.
RoundTrip
(
req
)
require
.
NoError
(
t
,
err
)
_
,
err
=
transport
.
RoundTrip
(
req
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
validateCalls
))
require
.
Equal
(
t
,
int32
(
2
),
atomic
.
LoadInt32
(
&
baseCalls
))
}
func
TestValidatedTransport_ExpiredCacheTriggersRevalidation
(
t
*
testing
.
T
)
{
originalValidate
:=
validateResolvedIP
defer
func
()
{
validateResolvedIP
=
originalValidate
}()
var
validateCalls
int32
validateResolvedIP
=
func
(
_
string
)
error
{
atomic
.
AddInt32
(
&
validateCalls
,
1
)
return
nil
}
base
:=
roundTripFunc
(
func
(
_
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
return
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{}`
)),
Header
:
make
(
http
.
Header
),
},
nil
})
now
:=
time
.
Unix
(
1730001000
,
0
)
transport
:=
newValidatedTransport
(
base
)
transport
.
now
=
func
()
time
.
Time
{
return
now
}
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
"https://api.openai.com/v1/responses"
,
nil
)
require
.
NoError
(
t
,
err
)
_
,
err
=
transport
.
RoundTrip
(
req
)
require
.
NoError
(
t
,
err
)
now
=
now
.
Add
(
validatedHostTTL
+
time
.
Second
)
_
,
err
=
transport
.
RoundTrip
(
req
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int32
(
2
),
atomic
.
LoadInt32
(
&
validateCalls
))
}
func
TestValidatedTransport_ValidationErrorStopsRoundTrip
(
t
*
testing
.
T
)
{
originalValidate
:=
validateResolvedIP
defer
func
()
{
validateResolvedIP
=
originalValidate
}()
expectedErr
:=
errors
.
New
(
"dns rebinding rejected"
)
validateResolvedIP
=
func
(
_
string
)
error
{
return
expectedErr
}
var
baseCalls
int32
base
:=
roundTripFunc
(
func
(
_
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
atomic
.
AddInt32
(
&
baseCalls
,
1
)
return
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{}`
))},
nil
})
transport
:=
newValidatedTransport
(
base
)
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
"https://api.openai.com/v1/responses"
,
nil
)
require
.
NoError
(
t
,
err
)
_
,
err
=
transport
.
RoundTrip
(
req
)
require
.
ErrorIs
(
t
,
err
,
expectedErr
)
require
.
Equal
(
t
,
int32
(
0
),
atomic
.
LoadInt32
(
&
baseCalls
))
}
backend/internal/pkg/httputil/body.go
0 → 100644
View file @
bb664d9b
package
httputil
import
(
"bytes"
"io"
"net/http"
)
const
(
requestBodyReadInitCap
=
512
requestBodyReadMaxInitCap
=
1
<<
20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
func
ReadRequestBodyWithPrealloc
(
req
*
http
.
Request
)
([]
byte
,
error
)
{
if
req
==
nil
||
req
.
Body
==
nil
{
return
nil
,
nil
}
capHint
:=
requestBodyReadInitCap
if
req
.
ContentLength
>
0
{
switch
{
case
req
.
ContentLength
<
int64
(
requestBodyReadInitCap
)
:
capHint
=
requestBodyReadInitCap
case
req
.
ContentLength
>
int64
(
requestBodyReadMaxInitCap
)
:
capHint
=
requestBodyReadMaxInitCap
default
:
capHint
=
int
(
req
.
ContentLength
)
}
}
buf
:=
bytes
.
NewBuffer
(
make
([]
byte
,
0
,
capHint
))
if
_
,
err
:=
io
.
Copy
(
buf
,
req
.
Body
);
err
!=
nil
{
return
nil
,
err
}
return
buf
.
Bytes
(),
nil
}
backend/internal/pkg/ip/ip.go
View file @
bb664d9b
...
@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
...
@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var
privateNets
[]
*
net
.
IPNet
var
privateNets
[]
*
net
.
IPNet
// CompiledIPRules 表示预编译的 IP 匹配规则。
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
type
CompiledIPRules
struct
{
CIDRs
[]
*
net
.
IPNet
IPs
[]
net
.
IP
PatternCount
int
}
func
init
()
{
func
init
()
{
for
_
,
cidr
:=
range
[]
string
{
for
_
,
cidr
:=
range
[]
string
{
"10.0.0.0/8"
,
"10.0.0.0/8"
,
...
@@ -84,6 +92,53 @@ func init() {
...
@@ -84,6 +92,53 @@ func init() {
}
}
}
}
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
func
CompileIPRules
(
patterns
[]
string
)
*
CompiledIPRules
{
compiled
:=
&
CompiledIPRules
{
CIDRs
:
make
([]
*
net
.
IPNet
,
0
,
len
(
patterns
)),
IPs
:
make
([]
net
.
IP
,
0
,
len
(
patterns
)),
PatternCount
:
len
(
patterns
),
}
for
_
,
pattern
:=
range
patterns
{
normalized
:=
strings
.
TrimSpace
(
pattern
)
if
normalized
==
""
{
continue
}
if
strings
.
Contains
(
normalized
,
"/"
)
{
_
,
cidr
,
err
:=
net
.
ParseCIDR
(
normalized
)
if
err
!=
nil
||
cidr
==
nil
{
continue
}
compiled
.
CIDRs
=
append
(
compiled
.
CIDRs
,
cidr
)
continue
}
parsedIP
:=
net
.
ParseIP
(
normalized
)
if
parsedIP
==
nil
{
continue
}
compiled
.
IPs
=
append
(
compiled
.
IPs
,
parsedIP
)
}
return
compiled
}
func
matchesCompiledRules
(
parsedIP
net
.
IP
,
rules
*
CompiledIPRules
)
bool
{
if
parsedIP
==
nil
||
rules
==
nil
{
return
false
}
for
_
,
cidr
:=
range
rules
.
CIDRs
{
if
cidr
.
Contains
(
parsedIP
)
{
return
true
}
}
for
_
,
ruleIP
:=
range
rules
.
IPs
{
if
parsedIP
.
Equal
(
ruleIP
)
{
return
true
}
}
return
false
}
// isPrivateIP 检查 IP 是否为私有地址。
// isPrivateIP 检查 IP 是否为私有地址。
func
isPrivateIP
(
ipStr
string
)
bool
{
func
isPrivateIP
(
ipStr
string
)
bool
{
ip
:=
net
.
ParseIP
(
ipStr
)
ip
:=
net
.
ParseIP
(
ipStr
)
...
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
...
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
// 2. 如果白名单不为空,IP 必须在白名单中
// 2. 如果白名单不为空,IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func
CheckIPRestriction
(
clientIP
string
,
whitelist
,
blacklist
[]
string
)
(
bool
,
string
)
{
func
CheckIPRestriction
(
clientIP
string
,
whitelist
,
blacklist
[]
string
)
(
bool
,
string
)
{
return
CheckIPRestrictionWithCompiledRules
(
clientIP
,
CompileIPRules
(
whitelist
),
CompileIPRules
(
blacklist
),
)
}
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
func
CheckIPRestrictionWithCompiledRules
(
clientIP
string
,
whitelist
,
blacklist
*
CompiledIPRules
)
(
bool
,
string
)
{
// 规范化 IP
// 规范化 IP
clientIP
=
normalizeIP
(
clientIP
)
clientIP
=
normalizeIP
(
clientIP
)
if
clientIP
==
""
{
if
clientIP
==
""
{
return
false
,
"access denied"
return
false
,
"access denied"
}
}
parsedIP
:=
net
.
ParseIP
(
clientIP
)
if
parsedIP
==
nil
{
return
false
,
"access denied"
}
// 1. 检查黑名单
// 1. 检查黑名单
if
len
(
blacklist
)
>
0
&&
M
atches
AnyPattern
(
client
IP
,
blacklist
)
{
if
blacklist
!=
nil
&&
blacklist
.
PatternCount
>
0
&&
m
atches
CompiledRules
(
parsed
IP
,
blacklist
)
{
return
false
,
"access denied"
return
false
,
"access denied"
}
}
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
if
len
(
whitelist
)
>
0
&&
!
M
atches
AnyPattern
(
client
IP
,
whitelist
)
{
if
whitelist
!=
nil
&&
whitelist
.
PatternCount
>
0
&&
!
m
atches
CompiledRules
(
parsed
IP
,
whitelist
)
{
return
false
,
"access denied"
return
false
,
"access denied"
}
}
...
...
backend/internal/pkg/ip/ip_test.go
View file @
bb664d9b
...
@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
...
@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
require
.
Equal
(
t
,
200
,
w
.
Code
)
require
.
Equal
(
t
,
200
,
w
.
Code
)
require
.
Equal
(
t
,
"9.9.9.9"
,
w
.
Body
.
String
())
require
.
Equal
(
t
,
"9.9.9.9"
,
w
.
Body
.
String
())
}
}
func
TestCheckIPRestrictionWithCompiledRules
(
t
*
testing
.
T
)
{
whitelist
:=
CompileIPRules
([]
string
{
"10.0.0.0/8"
,
"192.168.1.2"
})
blacklist
:=
CompileIPRules
([]
string
{
"10.1.1.1"
})
allowed
,
reason
:=
CheckIPRestrictionWithCompiledRules
(
"10.2.3.4"
,
whitelist
,
blacklist
)
require
.
True
(
t
,
allowed
)
require
.
Equal
(
t
,
""
,
reason
)
allowed
,
reason
=
CheckIPRestrictionWithCompiledRules
(
"10.1.1.1"
,
whitelist
,
blacklist
)
require
.
False
(
t
,
allowed
)
require
.
Equal
(
t
,
"access denied"
,
reason
)
}
func
TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies
(
t
*
testing
.
T
)
{
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
invalidWhitelist
:=
CompileIPRules
([]
string
{
"not-a-valid-pattern"
})
allowed
,
reason
:=
CheckIPRestrictionWithCompiledRules
(
"8.8.8.8"
,
invalidWhitelist
,
nil
)
require
.
False
(
t
,
allowed
)
require
.
Equal
(
t
,
"access denied"
,
reason
)
}
backend/internal/pkg/logger/logger.go
View file @
bb664d9b
...
@@ -10,6 +10,7 @@ import (
...
@@ -10,6 +10,7 @@ import (
"path/filepath"
"path/filepath"
"strings"
"strings"
"sync"
"sync"
"sync/atomic"
"time"
"time"
"go.uber.org/zap"
"go.uber.org/zap"
...
@@ -42,15 +43,19 @@ type LogEvent struct {
...
@@ -42,15 +43,19 @@ type LogEvent struct {
var
(
var
(
mu
sync
.
RWMutex
mu
sync
.
RWMutex
global
*
zap
.
Logger
global
atomic
.
Pointer
[
zap
.
Logger
]
sugar
*
zap
.
SugaredLogger
sugar
atomic
.
Pointer
[
zap
.
SugaredLogger
]
atomicLevel
zap
.
AtomicLevel
atomicLevel
zap
.
AtomicLevel
initOptions
InitOptions
initOptions
InitOptions
currentSink
Sink
currentSink
atomic
.
Value
// sinkState
stdLogUndo
func
()
stdLogUndo
func
()
bootstrapOnce
sync
.
Once
bootstrapOnce
sync
.
Once
)
)
type
sinkState
struct
{
sink
Sink
}
func
InitBootstrap
()
{
func
InitBootstrap
()
{
bootstrapOnce
.
Do
(
func
()
{
bootstrapOnce
.
Do
(
func
()
{
if
err
:=
Init
(
bootstrapOptions
());
err
!=
nil
{
if
err
:=
Init
(
bootstrapOptions
());
err
!=
nil
{
...
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
...
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
return
err
return
err
}
}
prev
:=
global
prev
:=
global
.
Load
()
global
=
zl
global
.
Store
(
zl
)
sugar
=
zl
.
Sugar
()
sugar
.
Store
(
zl
.
Sugar
()
)
atomicLevel
=
al
atomicLevel
=
al
initOptions
=
normalized
initOptions
=
normalized
...
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
...
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
func
CurrentLevel
()
string
{
func
CurrentLevel
()
string
{
mu
.
RLock
()
mu
.
RLock
()
defer
mu
.
RUnlock
()
defer
mu
.
RUnlock
()
if
global
==
nil
{
if
global
.
Load
()
==
nil
{
return
"info"
return
"info"
}
}
return
atomicLevel
.
Level
()
.
String
()
return
atomicLevel
.
Level
()
.
String
()
}
}
func
SetSink
(
sink
Sink
)
{
func
SetSink
(
sink
Sink
)
{
mu
.
Lock
()
currentSink
.
Store
(
sinkState
{
sink
:
sink
})
defer
mu
.
Unlock
()
}
currentSink
=
sink
func
loadSink
()
Sink
{
v
:=
currentSink
.
Load
()
if
v
==
nil
{
return
nil
}
state
,
ok
:=
v
.
(
sinkState
)
if
!
ok
{
return
nil
}
return
state
.
sink
}
}
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func
WriteSinkEvent
(
level
,
component
,
message
string
,
fields
map
[
string
]
any
)
{
func
WriteSinkEvent
(
level
,
component
,
message
string
,
fields
map
[
string
]
any
)
{
mu
.
RLock
()
sink
:=
loadSink
()
sink
:=
currentSink
mu
.
RUnlock
()
if
sink
==
nil
{
if
sink
==
nil
{
return
return
}
}
...
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
...
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
}
}
func
L
()
*
zap
.
Logger
{
func
L
()
*
zap
.
Logger
{
mu
.
RLock
()
if
l
:=
global
.
Load
();
l
!=
nil
{
defer
mu
.
RUnlock
()
return
l
if
global
!=
nil
{
return
global
}
}
return
zap
.
NewNop
()
return
zap
.
NewNop
()
}
}
func
S
()
*
zap
.
SugaredLogger
{
func
S
()
*
zap
.
SugaredLogger
{
mu
.
RLock
()
if
s
:=
sugar
.
Load
();
s
!=
nil
{
defer
mu
.
RUnlock
()
return
s
if
sugar
!=
nil
{
return
sugar
}
}
return
zap
.
NewNop
()
.
Sugar
()
return
zap
.
NewNop
()
.
Sugar
()
}
}
...
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
...
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
}
}
func
Sync
()
{
func
Sync
()
{
mu
.
RLock
()
l
:=
global
.
Load
()
l
:=
global
mu
.
RUnlock
()
if
l
!=
nil
{
if
l
!=
nil
{
_
=
l
.
Sync
()
_
=
l
.
Sync
()
}
}
...
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
...
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
log
.
SetFlags
(
0
)
log
.
SetFlags
(
0
)
log
.
SetPrefix
(
""
)
log
.
SetPrefix
(
""
)
log
.
SetOutput
(
newStdLogBridge
(
global
.
Named
(
"stdlog"
)))
base
:=
global
.
Load
()
if
base
==
nil
{
base
=
zap
.
NewNop
()
}
log
.
SetOutput
(
newStdLogBridge
(
base
.
Named
(
"stdlog"
)))
stdLogUndo
=
func
()
{
stdLogUndo
=
func
()
{
log
.
SetOutput
(
prevWriter
)
log
.
SetOutput
(
prevWriter
)
...
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
...
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
}
}
func
bridgeSlogLocked
()
{
func
bridgeSlogLocked
()
{
slog
.
SetDefault
(
slog
.
New
(
newSlogZapHandler
(
global
.
Named
(
"slog"
))))
base
:=
global
.
Load
()
if
base
==
nil
{
base
=
zap
.
NewNop
()
}
slog
.
SetDefault
(
slog
.
New
(
newSlogZapHandler
(
base
.
Named
(
"slog"
))))
}
}
func
buildLogger
(
options
InitOptions
)
(
*
zap
.
Logger
,
zap
.
AtomicLevel
,
error
)
{
func
buildLogger
(
options
InitOptions
)
(
*
zap
.
Logger
,
zap
.
AtomicLevel
,
error
)
{
...
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
...
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
func
(
s
*
sinkCore
)
Write
(
entry
zapcore
.
Entry
,
fields
[]
zapcore
.
Field
)
error
{
func
(
s
*
sinkCore
)
Write
(
entry
zapcore
.
Entry
,
fields
[]
zapcore
.
Field
)
error
{
// Only handle sink forwarding — the inner cores write via their own
// Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above).
// Write methods (added to CheckedEntry by s.core.Check above).
mu
.
RLock
()
sink
:=
loadSink
()
sink
:=
currentSink
mu
.
RUnlock
()
if
sink
==
nil
{
if
sink
==
nil
{
return
nil
return
nil
}
}
...
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
...
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
if
strings
.
Contains
(
lower
,
" failed"
)
||
strings
.
Contains
(
lower
,
"error"
)
||
strings
.
Contains
(
lower
,
"panic"
)
||
strings
.
Contains
(
lower
,
"fatal"
)
{
if
strings
.
Contains
(
lower
,
" failed"
)
||
strings
.
Contains
(
lower
,
"error"
)
||
strings
.
Contains
(
lower
,
"panic"
)
||
strings
.
Contains
(
lower
,
"fatal"
)
{
return
LevelError
return
LevelError
}
}
if
strings
.
Contains
(
lower
,
"warning"
)
||
strings
.
Contains
(
lower
,
"warn"
)
||
strings
.
Contains
(
lower
,
"
retry"
)
||
strings
.
Contains
(
lower
,
"
queue full"
)
||
strings
.
Contains
(
lower
,
"fallback"
)
{
if
strings
.
Contains
(
lower
,
"warning"
)
||
strings
.
Contains
(
lower
,
"warn"
)
||
strings
.
Contains
(
lower
,
" queue full"
)
||
strings
.
Contains
(
lower
,
"fallback"
)
{
return
LevelWarn
return
LevelWarn
}
}
return
LevelInfo
return
LevelInfo
...
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
...
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
return
return
}
}
mu
.
RLock
()
initialized
:=
global
.
Load
()
!=
nil
initialized
:=
global
!=
nil
mu
.
RUnlock
()
if
!
initialized
{
if
!
initialized
{
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
log
.
Print
(
msg
)
log
.
Print
(
msg
)
...
...
backend/internal/pkg/logger/slog_handler.go
View file @
bb664d9b
...
@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
...
@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
return
true
return
true
})
})
entry
:=
h
.
logger
.
With
(
fields
...
)
switch
{
switch
{
case
record
.
Level
>=
slog
.
LevelError
:
case
record
.
Level
>=
slog
.
LevelError
:
entry
.
Error
(
record
.
Message
)
h
.
logger
.
Error
(
record
.
Message
,
fields
...
)
case
record
.
Level
>=
slog
.
LevelWarn
:
case
record
.
Level
>=
slog
.
LevelWarn
:
entry
.
Warn
(
record
.
Message
)
h
.
logger
.
Warn
(
record
.
Message
,
fields
...
)
case
record
.
Level
<=
slog
.
LevelDebug
:
case
record
.
Level
<=
slog
.
LevelDebug
:
entry
.
Debug
(
record
.
Message
)
h
.
logger
.
Debug
(
record
.
Message
,
fields
...
)
default
:
default
:
entry
.
Info
(
record
.
Message
)
h
.
logger
.
Info
(
record
.
Message
,
fields
...
)
}
}
return
nil
return
nil
}
}
...
...
backend/internal/pkg/logger/stdlog_bridge_test.go
View file @
bb664d9b
...
@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
...
@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
{
msg
:
"Warning: queue full"
,
want
:
LevelWarn
},
{
msg
:
"Warning: queue full"
,
want
:
LevelWarn
},
{
msg
:
"Forward request failed: timeout"
,
want
:
LevelError
},
{
msg
:
"Forward request failed: timeout"
,
want
:
LevelError
},
{
msg
:
"[ERROR] upstream unavailable"
,
want
:
LevelError
},
{
msg
:
"[ERROR] upstream unavailable"
,
want
:
LevelError
},
{
msg
:
"[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5"
,
want
:
LevelInfo
},
{
msg
:
"service started"
,
want
:
LevelInfo
},
{
msg
:
"service started"
,
want
:
LevelInfo
},
{
msg
:
"debug: cache miss"
,
want
:
LevelDebug
},
{
msg
:
"debug: cache miss"
,
want
:
LevelDebug
},
}
}
...
...
backend/internal/pkg/openai/oauth.go
View file @
bb664d9b
...
@@ -36,10 +36,18 @@ const (
...
@@ -36,10 +36,18 @@ const (
SessionTTL
=
30
*
time
.
Minute
SessionTTL
=
30
*
time
.
Minute
)
)
const
(
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI
=
"openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora
=
"sora"
)
// OAuthSession stores OAuth flow state for OpenAI
// OAuthSession stores OAuth flow state for OpenAI
type
OAuthSession
struct
{
type
OAuthSession
struct
{
State
string
`json:"state"`
State
string
`json:"state"`
CodeVerifier
string
`json:"code_verifier"`
CodeVerifier
string
`json:"code_verifier"`
ClientID
string
`json:"client_id,omitempty"`
ProxyURL
string
`json:"proxy_url,omitempty"`
ProxyURL
string
`json:"proxy_url,omitempty"`
RedirectURI
string
`json:"redirect_uri"`
RedirectURI
string
`json:"redirect_uri"`
CreatedAt
time
.
Time
`json:"created_at"`
CreatedAt
time
.
Time
`json:"created_at"`
...
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
...
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func
BuildAuthorizationURL
(
state
,
codeChallenge
,
redirectURI
string
)
string
{
func
BuildAuthorizationURL
(
state
,
codeChallenge
,
redirectURI
string
)
string
{
return
BuildAuthorizationURLForPlatform
(
state
,
codeChallenge
,
redirectURI
,
OAuthPlatformOpenAI
)
}
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
func
BuildAuthorizationURLForPlatform
(
state
,
codeChallenge
,
redirectURI
,
platform
string
)
string
{
if
redirectURI
==
""
{
if
redirectURI
==
""
{
redirectURI
=
DefaultRedirectURI
redirectURI
=
DefaultRedirectURI
}
}
clientID
,
codexFlow
:=
OAuthClientConfigByPlatform
(
platform
)
params
:=
url
.
Values
{}
params
:=
url
.
Values
{}
params
.
Set
(
"response_type"
,
"code"
)
params
.
Set
(
"response_type"
,
"code"
)
params
.
Set
(
"client_id"
,
C
lientID
)
params
.
Set
(
"client_id"
,
c
lientID
)
params
.
Set
(
"redirect_uri"
,
redirectURI
)
params
.
Set
(
"redirect_uri"
,
redirectURI
)
params
.
Set
(
"scope"
,
DefaultScopes
)
params
.
Set
(
"scope"
,
DefaultScopes
)
params
.
Set
(
"state"
,
state
)
params
.
Set
(
"state"
,
state
)
...
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
...
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
params
.
Set
(
"code_challenge_method"
,
"S256"
)
params
.
Set
(
"code_challenge_method"
,
"S256"
)
// OpenAI specific parameters
// OpenAI specific parameters
params
.
Set
(
"id_token_add_organizations"
,
"true"
)
params
.
Set
(
"id_token_add_organizations"
,
"true"
)
params
.
Set
(
"codex_cli_simplified_flow"
,
"true"
)
if
codexFlow
{
params
.
Set
(
"codex_cli_simplified_flow"
,
"true"
)
}
return
fmt
.
Sprintf
(
"%s?%s"
,
AuthorizeURL
,
params
.
Encode
())
return
fmt
.
Sprintf
(
"%s?%s"
,
AuthorizeURL
,
params
.
Encode
())
}
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func
OAuthClientConfigByPlatform
(
platform
string
)
(
clientID
string
,
codexFlow
bool
)
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
platform
))
{
case
OAuthPlatformSora
:
return
ClientID
,
false
default
:
return
ClientID
,
true
}
}
// TokenRequest represents the token exchange request body
// TokenRequest represents the token exchange request body
type
TokenRequest
struct
{
type
TokenRequest
struct
{
GrantType
string
`json:"grant_type"`
GrantType
string
`json:"grant_type"`
...
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
...
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
return
params
.
Encode
()
return
params
.
Encode
()
}
}
// ParseIDToken parses the ID Token JWT and extracts claims
// ParseIDToken parses the ID Token JWT and extracts claims.
// Note: This does NOT verify the signature - it only decodes the payload
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
// For production, you should verify the token signature using OpenAI's public keys
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func
ParseIDToken
(
idToken
string
)
(
*
IDTokenClaims
,
error
)
{
func
ParseIDToken
(
idToken
string
)
(
*
IDTokenClaims
,
error
)
{
parts
:=
strings
.
Split
(
idToken
,
"."
)
parts
:=
strings
.
Split
(
idToken
,
"."
)
if
len
(
parts
)
!=
3
{
if
len
(
parts
)
!=
3
{
...
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
...
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return
nil
,
fmt
.
Errorf
(
"failed to parse JWT claims: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to parse JWT claims: %w"
,
err
)
}
}
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
const
clockSkewTolerance
=
120
// 秒
now
:=
time
.
Now
()
.
Unix
()
if
claims
.
Exp
>
0
&&
now
>
claims
.
Exp
+
clockSkewTolerance
{
return
nil
,
fmt
.
Errorf
(
"id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)"
,
claims
.
Exp
,
now
,
clockSkewTolerance
)
}
return
&
claims
,
nil
return
&
claims
,
nil
}
}
...
...
backend/internal/pkg/openai/oauth_test.go
View file @
bb664d9b
package
openai
package
openai
import
(
import
(
"net/url"
"sync"
"sync"
"testing"
"testing"
"time"
"time"
...
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
...
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
t
.
Fatal
(
"stopCh 未关闭"
)
t
.
Fatal
(
"stopCh 未关闭"
)
}
}
}
}
func
TestBuildAuthorizationURLForPlatform_OpenAI
(
t
*
testing
.
T
)
{
authURL
:=
BuildAuthorizationURLForPlatform
(
"state-1"
,
"challenge-1"
,
DefaultRedirectURI
,
OAuthPlatformOpenAI
)
parsed
,
err
:=
url
.
Parse
(
authURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"Parse URL failed: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
got
:=
q
.
Get
(
"client_id"
);
got
!=
ClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q"
,
got
,
ClientID
)
}
if
got
:=
q
.
Get
(
"codex_cli_simplified_flow"
);
got
!=
"true"
{
t
.
Fatalf
(
"codex flow mismatch: got=%q want=true"
,
got
)
}
if
got
:=
q
.
Get
(
"id_token_add_organizations"
);
got
!=
"true"
{
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func
TestBuildAuthorizationURLForPlatform_Sora
(
t
*
testing
.
T
)
{
authURL
:=
BuildAuthorizationURLForPlatform
(
"state-2"
,
"challenge-2"
,
DefaultRedirectURI
,
OAuthPlatformSora
)
parsed
,
err
:=
url
.
Parse
(
authURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"Parse URL failed: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
got
:=
q
.
Get
(
"client_id"
);
got
!=
ClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)"
,
got
,
ClientID
)
}
if
got
:=
q
.
Get
(
"codex_cli_simplified_flow"
);
got
!=
""
{
t
.
Fatalf
(
"codex flow should be empty for sora, got=%q"
,
got
)
}
if
got
:=
q
.
Get
(
"id_token_add_organizations"
);
got
!=
"true"
{
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
backend/internal/pkg/response/response_test.go
View file @
bb664d9b
...
@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
...
@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
t
.
Helper
()
t
.
Helper
()
// 先用 raw json 解析,因为 Data 是 any 类型
// 先用 raw json 解析,因为 Data 是 any 类型
var
raw
struct
{
var
raw
struct
{
Code
int
`json:"code"`
Code
int
`json:"code"`
Message
string
`json:"message"`
Message
string
`json:"message"`
Reason
string
`json:"reason,omitempty"`
Reason
string
`json:"reason,omitempty"`
Data
json
.
RawMessage
`json:"data,omitempty"`
Data
json
.
RawMessage
`json:"data,omitempty"`
}
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
raw
))
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
raw
))
...
...
backend/internal/pkg/tlsfingerprint/dialer.go
View file @
bb664d9b
...
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
...
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
"cipher_suites"
,
len
(
spec
.
CipherSuites
),
"cipher_suites"
,
len
(
spec
.
CipherSuites
),
"extensions"
,
len
(
spec
.
Extensions
),
"extensions"
,
len
(
spec
.
Extensions
),
"compression_methods"
,
spec
.
CompressionMethods
,
"compression_methods"
,
spec
.
CompressionMethods
,
"tls_vers_max"
,
fmt
.
Sprintf
(
"0x%04x"
,
spec
.
TLSVersMax
)
,
"tls_vers_max"
,
spec
.
TLSVersMax
,
"tls_vers_min"
,
fmt
.
Sprintf
(
"0x%04x"
,
spec
.
TLSVersMin
)
)
"tls_vers_min"
,
spec
.
TLSVersMin
)
if
d
.
profile
!=
nil
{
if
d
.
profile
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_using_profile"
,
"name"
,
d
.
profile
.
Name
,
"grease"
,
d
.
profile
.
EnableGREASE
)
slog
.
Debug
(
"tls_fingerprint_socks5_using_profile"
,
"name"
,
d
.
profile
.
Name
,
"grease"
,
d
.
profile
.
EnableGREASE
)
...
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
...
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
state
:=
tlsConn
.
ConnectionState
()
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_socks5_handshake_success"
,
slog
.
Debug
(
"tls_fingerprint_socks5_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
)
,
"version"
,
state
.
Version
,
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
)
,
"cipher_suite"
,
state
.
CipherSuite
,
"alpn"
,
state
.
NegotiatedProtocol
)
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
return
tlsConn
,
nil
...
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
...
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
state
:=
tlsConn
.
ConnectionState
()
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_http_proxy_handshake_success"
,
slog
.
Debug
(
"tls_fingerprint_http_proxy_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
)
,
"version"
,
state
.
Version
,
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
)
,
"cipher_suite"
,
state
.
CipherSuite
,
"alpn"
,
state
.
NegotiatedProtocol
)
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
return
tlsConn
,
nil
...
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
...
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
// Log successful handshake details
// Log successful handshake details
state
:=
tlsConn
.
ConnectionState
()
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_handshake_success"
,
slog
.
Debug
(
"tls_fingerprint_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
)
,
"version"
,
state
.
Version
,
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
)
,
"cipher_suite"
,
state
.
CipherSuite
,
"alpn"
,
state
.
NegotiatedProtocol
)
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
return
tlsConn
,
nil
...
...
backend/internal/pkg/usagestats/usage_log_types.go
View file @
bb664d9b
...
@@ -139,6 +139,7 @@ type UsageLogFilters struct {
...
@@ -139,6 +139,7 @@ type UsageLogFilters struct {
AccountID
int64
AccountID
int64
GroupID
int64
GroupID
int64
Model
string
Model
string
RequestType
*
int16
Stream
*
bool
Stream
*
bool
BillingType
*
int8
BillingType
*
int8
StartTime
*
time
.
Time
StartTime
*
time
.
Time
...
...
backend/internal/repository/account_repo.go
View file @
bb664d9b
...
@@ -50,11 +50,6 @@ type accountRepository struct {
...
@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache
service
.
SchedulerCache
schedulerCache
service
.
SchedulerCache
}
}
// tempUnschedSnapshot holds the temporary-unschedulable state read for one
// account from the accounts table (see loadTempUnschedStates): the deadline
// and the reason recorded when the account was marked unschedulable.
type tempUnschedSnapshot struct {
	until  *time.Time // temp_unschedulable_until; nil when the column is NULL
	reason string     // temp_unschedulable_reason; "" when the column is NULL
}
// NewAccountRepository 创建账户仓储实例。
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func
NewAccountRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
,
schedulerCache
service
.
SchedulerCache
)
service
.
AccountRepository
{
func
NewAccountRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
,
schedulerCache
service
.
SchedulerCache
)
service
.
AccountRepository
{
...
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
...
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs
=
append
(
accountIDs
,
acc
.
ID
)
accountIDs
=
append
(
accountIDs
,
acc
.
ID
)
}
}
tempUnschedMap
,
err
:=
r
.
loadTempUnschedStates
(
ctx
,
accountIDs
)
if
err
!=
nil
{
return
nil
,
err
}
groupsByAccount
,
groupIDsByAccount
,
accountGroupsByAccount
,
err
:=
r
.
loadAccountGroups
(
ctx
,
accountIDs
)
groupsByAccount
,
groupIDsByAccount
,
accountGroupsByAccount
,
err
:=
r
.
loadAccountGroups
(
ctx
,
accountIDs
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
...
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if
ags
,
ok
:=
accountGroupsByAccount
[
entAcc
.
ID
];
ok
{
if
ags
,
ok
:=
accountGroupsByAccount
[
entAcc
.
ID
];
ok
{
out
.
AccountGroups
=
ags
out
.
AccountGroups
=
ags
}
}
if
snap
,
ok
:=
tempUnschedMap
[
entAcc
.
ID
];
ok
{
out
.
TempUnschedulableUntil
=
snap
.
until
out
.
TempUnschedulableReason
=
snap
.
reason
}
outByID
[
entAcc
.
ID
]
=
out
outByID
[
entAcc
.
ID
]
=
out
}
}
...
@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
...
@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
}
}
func
(
r
*
accountRepository
)
syncSchedulerAccountSnapshots
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
{
if
r
==
nil
||
r
.
schedulerCache
==
nil
||
len
(
accountIDs
)
==
0
{
return
}
uniqueIDs
:=
make
([]
int64
,
0
,
len
(
accountIDs
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
accountIDs
))
for
_
,
id
:=
range
accountIDs
{
if
id
<=
0
{
continue
}
if
_
,
exists
:=
seen
[
id
];
exists
{
continue
}
seen
[
id
]
=
struct
{}{}
uniqueIDs
=
append
(
uniqueIDs
,
id
)
}
if
len
(
uniqueIDs
)
==
0
{
return
}
accounts
,
err
:=
r
.
GetByIDs
(
ctx
,
uniqueIDs
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.account"
,
"[Scheduler] batch sync account snapshot read failed: count=%d err=%v"
,
len
(
uniqueIDs
),
err
)
return
}
for
_
,
account
:=
range
accounts
{
if
account
==
nil
{
continue
}
if
err
:=
r
.
schedulerCache
.
SetAccount
(
ctx
,
account
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.account"
,
"[Scheduler] batch sync account snapshot write failed: id=%d err=%v"
,
account
.
ID
,
err
)
}
}
}
func
(
r
*
accountRepository
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
accountRepository
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
_
,
err
:=
r
.
client
.
Account
.
Update
()
.
_
,
err
:=
r
.
client
.
Account
.
Update
()
.
Where
(
dbaccount
.
IDEQ
(
id
))
.
Where
(
dbaccount
.
IDEQ
(
id
))
.
...
@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
...
@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync
=
true
shouldSync
=
true
}
}
if
shouldSync
{
if
shouldSync
{
for
_
,
id
:=
range
ids
{
r
.
syncSchedulerAccountSnapshots
(
ctx
,
ids
)
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
}
}
}
}
}
return
rows
,
nil
return
rows
,
nil
...
@@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
...
@@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
tempUnschedMap
,
err
:=
r
.
loadTempUnschedStates
(
ctx
,
accountIDs
)
if
err
!=
nil
{
return
nil
,
err
}
groupsByAccount
,
groupIDsByAccount
,
accountGroupsByAccount
,
err
:=
r
.
loadAccountGroups
(
ctx
,
accountIDs
)
groupsByAccount
,
groupIDsByAccount
,
accountGroupsByAccount
,
err
:=
r
.
loadAccountGroups
(
ctx
,
accountIDs
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
...
@@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if
ags
,
ok
:=
accountGroupsByAccount
[
acc
.
ID
];
ok
{
if
ags
,
ok
:=
accountGroupsByAccount
[
acc
.
ID
];
ok
{
out
.
AccountGroups
=
ags
out
.
AccountGroups
=
ags
}
}
if
snap
,
ok
:=
tempUnschedMap
[
acc
.
ID
];
ok
{
out
.
TempUnschedulableUntil
=
snap
.
until
out
.
TempUnschedulableReason
=
snap
.
reason
}
outAccounts
=
append
(
outAccounts
,
*
out
)
outAccounts
=
append
(
outAccounts
,
*
out
)
}
}
...
@@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
...
@@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
)
)
}
}
func
(
r
*
accountRepository
)
loadTempUnschedStates
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
tempUnschedSnapshot
,
error
)
{
out
:=
make
(
map
[
int64
]
tempUnschedSnapshot
)
if
len
(
accountIDs
)
==
0
{
return
out
,
nil
}
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`
,
pq
.
Array
(
accountIDs
))
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
id
int64
var
until
sql
.
NullTime
var
reason
sql
.
NullString
if
err
:=
rows
.
Scan
(
&
id
,
&
until
,
&
reason
);
err
!=
nil
{
return
nil
,
err
}
var
untilPtr
*
time
.
Time
if
until
.
Valid
{
tmp
:=
until
.
Time
untilPtr
=
&
tmp
}
if
reason
.
Valid
{
out
[
id
]
=
tempUnschedSnapshot
{
until
:
untilPtr
,
reason
:
reason
.
String
}
}
else
{
out
[
id
]
=
tempUnschedSnapshot
{
until
:
untilPtr
,
reason
:
""
}
}
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
func
(
r
*
accountRepository
)
loadProxies
(
ctx
context
.
Context
,
proxyIDs
[]
int64
)
(
map
[
int64
]
*
service
.
Proxy
,
error
)
{
func
(
r
*
accountRepository
)
loadProxies
(
ctx
context
.
Context
,
proxyIDs
[]
int64
)
(
map
[
int64
]
*
service
.
Proxy
,
error
)
{
proxyMap
:=
make
(
map
[
int64
]
*
service
.
Proxy
)
proxyMap
:=
make
(
map
[
int64
]
*
service
.
Proxy
)
if
len
(
proxyIDs
)
==
0
{
if
len
(
proxyIDs
)
==
0
{
...
@@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
...
@@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
rateMultiplier
:=
m
.
RateMultiplier
rateMultiplier
:=
m
.
RateMultiplier
return
&
service
.
Account
{
return
&
service
.
Account
{
ID
:
m
.
ID
,
ID
:
m
.
ID
,
Name
:
m
.
Name
,
Name
:
m
.
Name
,
Notes
:
m
.
Notes
,
Notes
:
m
.
Notes
,
Platform
:
m
.
Platform
,
Platform
:
m
.
Platform
,
Type
:
m
.
Type
,
Type
:
m
.
Type
,
Credentials
:
copyJSONMap
(
m
.
Credentials
),
Credentials
:
copyJSONMap
(
m
.
Credentials
),
Extra
:
copyJSONMap
(
m
.
Extra
),
Extra
:
copyJSONMap
(
m
.
Extra
),
ProxyID
:
m
.
ProxyID
,
ProxyID
:
m
.
ProxyID
,
Concurrency
:
m
.
Concurrency
,
Concurrency
:
m
.
Concurrency
,
Priority
:
m
.
Priority
,
Priority
:
m
.
Priority
,
RateMultiplier
:
&
rateMultiplier
,
RateMultiplier
:
&
rateMultiplier
,
Status
:
m
.
Status
,
Status
:
m
.
Status
,
ErrorMessage
:
derefString
(
m
.
ErrorMessage
),
ErrorMessage
:
derefString
(
m
.
ErrorMessage
),
LastUsedAt
:
m
.
LastUsedAt
,
LastUsedAt
:
m
.
LastUsedAt
,
ExpiresAt
:
m
.
ExpiresAt
,
ExpiresAt
:
m
.
ExpiresAt
,
AutoPauseOnExpired
:
m
.
AutoPauseOnExpired
,
AutoPauseOnExpired
:
m
.
AutoPauseOnExpired
,
CreatedAt
:
m
.
CreatedAt
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
Schedulable
:
m
.
Schedulable
,
Schedulable
:
m
.
Schedulable
,
RateLimitedAt
:
m
.
RateLimitedAt
,
RateLimitedAt
:
m
.
RateLimitedAt
,
RateLimitResetAt
:
m
.
RateLimitResetAt
,
RateLimitResetAt
:
m
.
RateLimitResetAt
,
OverloadUntil
:
m
.
OverloadUntil
,
OverloadUntil
:
m
.
OverloadUntil
,
SessionWindowStart
:
m
.
SessionWindowStart
,
TempUnschedulableUntil
:
m
.
TempUnschedulableUntil
,
SessionWindowEnd
:
m
.
SessionWindowEnd
,
TempUnschedulableReason
:
derefString
(
m
.
TempUnschedulableReason
),
SessionWindowStatus
:
derefString
(
m
.
SessionWindowStatus
),
SessionWindowStart
:
m
.
SessionWindowStart
,
SessionWindowEnd
:
m
.
SessionWindowEnd
,
SessionWindowStatus
:
derefString
(
m
.
SessionWindowStatus
),
}
}
}
}
...
...
backend/internal/repository/account_repo_integration_test.go
View file @
bb664d9b
...
@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
...
@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s
.
Require
()
.
Nil
(
got
.
OverloadUntil
)
s
.
Require
()
.
Nil
(
got
.
OverloadUntil
)
}
}
// TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs verifies that the
// temporary-unschedulable fields round-trip through SetTempUnschedulable,
// are surfaced by both GetByID and GetByIDs, and are fully reset by
// ClearTempUnschedulable.
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
	acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
	acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})

	// Truncate to whole seconds so the WithinDuration comparisons below are
	// stable against database timestamp precision.
	until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
	reason := `{"rule":"429","matched_keyword":"too many requests"}`
	s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))

	// GetByID must surface both fields for the marked account.
	gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
	s.Require().NoError(err)
	s.Require().NotNil(gotByID.TempUnschedulableUntil)
	s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
	s.Require().Equal(reason, gotByID.TempUnschedulableReason)

	// GetByIDs must return accounts in the requested order and populate the
	// fields only on the account that was actually marked (acc2 stays clean).
	gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
	s.Require().NoError(err)
	s.Require().Len(gotByIDs, 2)
	s.Require().Equal(acc2.ID, gotByIDs[0].ID)
	s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
	s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
	s.Require().Equal(acc1.ID, gotByIDs[1].ID)
	s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
	s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
	s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)

	// Clearing resets both fields to their zero values.
	s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
	cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
	s.Require().NoError(err)
	s.Require().Nil(cleared.TempUnschedulableUntil)
	s.Require().Equal("", cleared.TempUnschedulableReason)
}
// --- UpdateLastUsed ---
// --- UpdateLastUsed ---
func
(
s
*
AccountRepoSuite
)
TestUpdateLastUsed
()
{
func
(
s
*
AccountRepoSuite
)
TestUpdateLastUsed
()
{
...
...
backend/internal/repository/api_key_repo.go
View file @
bb664d9b
...
@@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User {
...
@@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User {
return
nil
return
nil
}
}
return
&
service
.
User
{
return
&
service
.
User
{
ID
:
u
.
ID
,
ID
:
u
.
ID
,
Email
:
u
.
Email
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Username
:
u
.
Username
,
Notes
:
u
.
Notes
,
Notes
:
u
.
Notes
,
PasswordHash
:
u
.
PasswordHash
,
PasswordHash
:
u
.
PasswordHash
,
Role
:
u
.
Role
,
Role
:
u
.
Role
,
Balance
:
u
.
Balance
,
Balance
:
u
.
Balance
,
Concurrency
:
u
.
Concurrency
,
Concurrency
:
u
.
Concurrency
,
Status
:
u
.
Status
,
Status
:
u
.
Status
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
SoraStorageQuotaBytes
:
u
.
SoraStorageQuotaBytes
,
TotpEnabled
:
u
.
TotpEnabled
,
SoraStorageUsedBytes
:
u
.
SoraStorageUsedBytes
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
CreatedAt
:
u
.
CreatedAt
,
TotpEnabled
:
u
.
TotpEnabled
,
UpdatedAt
:
u
.
UpdatedAt
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
CreatedAt
:
u
.
CreatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
}
}
}
}
...
@@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
...
@@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
SoraImagePrice540
:
g
.
SoraImagePrice540
,
SoraImagePrice540
:
g
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
g
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequest
:
g
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
g
.
SoraVideoPricePerRequestHd
,
SoraVideoPricePerRequestHD
:
g
.
SoraVideoPricePerRequestHd
,
SoraStorageQuotaBytes
:
g
.
SoraStorageQuotaBytes
,
DefaultValidityDays
:
g
.
DefaultValidityDays
,
DefaultValidityDays
:
g
.
DefaultValidityDays
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
FallbackGroupID
:
g
.
FallbackGroupID
,
...
...
backend/internal/repository/concurrency_cache.go
View file @
bb664d9b
...
@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
...
@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
return
result
,
nil
return
result
,
nil
}
}
// GetAccountConcurrencyBatch returns the current slot count for each account
// in accountIDs using a single Redis pipeline round trip.
//
// For every account it first evicts slot entries whose score is older than
// slotTTLSeconds and then reads the remaining cardinality, so the reported
// counts exclude expired slots. The cutoff is computed from the Redis server
// clock (TIME) rather than the local clock, keeping it consistent with the
// scores written by slot acquirers.
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
	if len(accountIDs) == 0 {
		return map[int64]int{}, nil
	}
	// Use the Redis server clock so the TTL cutoff agrees with slot writers.
	now, err := c.rdb.Time(ctx).Result()
	if err != nil {
		return nil, fmt.Errorf("redis TIME: %w", err)
	}
	cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
	pipe := c.rdb.Pipeline()
	// accountCmd pairs an account ID with its queued ZCARD command so results
	// can be collected after the pipeline executes.
	type accountCmd struct {
		accountID int64
		zcardCmd  *redis.IntCmd
	}
	cmds := make([]accountCmd, 0, len(accountIDs))
	for _, accountID := range accountIDs {
		slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
		// Prune expired slot members first so the following ZCARD counts only
		// live slots.
		pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
		cmds = append(cmds, accountCmd{
			accountID: accountID,
			zcardCmd:  pipe.ZCard(ctx, slotKey),
		})
	}
	// redis.Nil is tolerated: the code treats missing keys as zero occupancy.
	if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
		return nil, fmt.Errorf("pipeline exec: %w", err)
	}
	result := make(map[int64]int, len(accountIDs))
	for _, cmd := range cmds {
		result[cmd.accountID] = int(cmd.zcardCmd.Val())
	}
	return result, nil
}
// User slot operations
// User slot operations
func
(
c
*
concurrencyCache
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
func
(
c
*
concurrencyCache
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
...
...
backend/internal/repository/gateway_cache_integration_test.go
View file @
bb664d9b
...
@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
...
@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require
.
False
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected parsing error, not redis.Nil"
)
require
.
False
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected parsing error, not redis.Nil"
)
}
}
func
TestGatewayCacheSuite
(
t
*
testing
.
T
)
{
func
TestGatewayCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GatewayCacheSuite
))
suite
.
Run
(
t
,
new
(
GatewayCacheSuite
))
}
}
backend/internal/repository/group_repo.go
View file @
bb664d9b
...
@@ -4,6 +4,8 @@ import (
...
@@ -4,6 +4,8 @@ import (
"context"
"context"
"database/sql"
"database/sql"
"errors"
"errors"
"fmt"
"strings"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/apikey"
...
@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
...
@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetNillableFallbackGroupIDOnInvalidRequest
(
groupIn
.
FallbackGroupIDOnInvalidRequest
)
.
SetNillableFallbackGroupIDOnInvalidRequest
(
groupIn
.
FallbackGroupIDOnInvalidRequest
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetSoraStorageQuotaBytes
(
groupIn
.
SoraStorageQuotaBytes
)
// 设置模型路由配置
// 设置模型路由配置
if
groupIn
.
ModelRouting
!=
nil
{
if
groupIn
.
ModelRouting
!=
nil
{
...
@@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
...
@@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetSoraStorageQuotaBytes
(
groupIn
.
SoraStorageQuotaBytes
)
// 处理 FallbackGroupID:nil 时清除,否则设置
// 处理 FallbackGroupID:nil 时清除,否则设置
if
groupIn
.
FallbackGroupID
!=
nil
{
if
groupIn
.
FallbackGroupID
!=
nil
{
...
@@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
...
@@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
return
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
))
.
Exist
(
ctx
)
return
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
NameEQ
(
name
))
.
Exist
(
ctx
)
}
}
// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
// 返回结构:map[groupID]exists。
func
(
r
*
groupRepository
)
ExistsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
(
map
[
int64
]
bool
,
error
)
{
result
:=
make
(
map
[
int64
]
bool
,
len
(
ids
))
if
len
(
ids
)
==
0
{
return
result
,
nil
}
uniqueIDs
:=
make
([]
int64
,
0
,
len
(
ids
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
ids
))
for
_
,
id
:=
range
ids
{
if
id
<=
0
{
continue
}
if
_
,
ok
:=
seen
[
id
];
ok
{
continue
}
seen
[
id
]
=
struct
{}{}
uniqueIDs
=
append
(
uniqueIDs
,
id
)
result
[
id
]
=
false
}
if
len
(
uniqueIDs
)
==
0
{
return
result
,
nil
}
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT id
FROM groups
WHERE id = ANY($1) AND deleted_at IS NULL
`
,
pq
.
Array
(
uniqueIDs
))
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
id
int64
if
err
:=
rows
.
Scan
(
&
id
);
err
!=
nil
{
return
nil
,
err
}
result
[
id
]
=
true
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
result
,
nil
}
func
(
r
*
groupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
groupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
var
count
int64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT COUNT(*) FROM account_groups WHERE group_id = $1"
,
[]
any
{
groupID
},
&
count
);
err
!=
nil
{
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT COUNT(*) FROM account_groups WHERE group_id = $1"
,
[]
any
{
groupID
},
&
count
);
err
!=
nil
{
...
@@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
...
@@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
return
nil
return
nil
}
}
// 使用事务批量更新
// 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
tx
,
err
:=
r
.
client
.
Tx
(
ctx
)
sortOrderByID
:=
make
(
map
[
int64
]
int
,
len
(
updates
))
if
err
!=
nil
{
groupIDs
:=
make
([]
int64
,
0
,
len
(
updates
))
for
_
,
u
:=
range
updates
{
if
u
.
ID
<=
0
{
continue
}
if
_
,
exists
:=
sortOrderByID
[
u
.
ID
];
!
exists
{
groupIDs
=
append
(
groupIDs
,
u
.
ID
)
}
sortOrderByID
[
u
.
ID
]
=
u
.
SortOrder
}
if
len
(
groupIDs
)
==
0
{
return
nil
}
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。
var
existingCount
int
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`
,
[]
any
{
pq
.
Array
(
groupIDs
)},
&
existingCount
,
);
err
!=
nil
{
return
err
return
err
}
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
if
existingCount
!=
len
(
groupIDs
)
{
return
service
.
ErrGroupNotFound
}
for
_
,
u
:=
range
updates
{
args
:=
make
([]
any
,
0
,
len
(
groupIDs
)
*
2
+
1
)
if
_
,
err
:=
tx
.
Group
.
UpdateOneID
(
u
.
ID
)
.
SetSortOrder
(
u
.
SortOrder
)
.
Save
(
ctx
);
err
!=
nil
{
caseClauses
:=
make
([]
string
,
0
,
len
(
groupIDs
))
return
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
placeholder
:=
1
}
for
_
,
id
:=
range
groupIDs
{
caseClauses
=
append
(
caseClauses
,
fmt
.
Sprintf
(
"WHEN $%d THEN $%d"
,
placeholder
,
placeholder
+
1
))
args
=
append
(
args
,
id
,
sortOrderByID
[
id
])
placeholder
+=
2
}
}
args
=
append
(
args
,
pq
.
Array
(
groupIDs
))
query
:=
fmt
.
Sprintf
(
`
UPDATE groups
SET sort_order = CASE id
%s
ELSE sort_order
END
WHERE deleted_at IS NULL AND id = ANY($%d)
`
,
strings
.
Join
(
caseClauses
,
"
\n\t\t\t
"
),
placeholder
)
if
err
:=
tx
.
Commit
();
err
!=
nil
{
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
args
...
)
if
err
!=
nil
{
return
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
err
return
err
}
}
if
affected
!=
int64
(
len
(
groupIDs
))
{
return
service
.
ErrGroupNotFound
}
for
_
,
id
:=
range
groupIDs
{
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventGroupChanged
,
nil
,
&
id
,
nil
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.group"
,
"[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v"
,
id
,
err
)
}
}
return
nil
return
nil
}
}
Prev
1
2
3
4
5
6
7
8
9
10
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment