Commit bb664d9b authored by yangjianbo's avatar yangjianbo
Browse files

feat(sync): full code sync from release

parent bfc7b339
...@@ -39,7 +39,7 @@ const ( ...@@ -39,7 +39,7 @@ const (
// They enable the "login without creating your own OAuth client" experience, but Google may // They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client. // restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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" GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret"
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
......
...@@ -32,6 +32,7 @@ const ( ...@@ -32,6 +32,7 @@ const (
defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
) )
// Options 定义共享 HTTP 客户端的构建参数 // Options 定义共享 HTTP 客户端的构建参数
...@@ -53,6 +54,9 @@ type Options struct { ...@@ -53,6 +54,9 @@ type Options struct {
// sharedClients 存储按配置参数缓存的 http.Client 实例 // sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map var sharedClients sync.Map
// 允许测试替换校验函数,生产默认指向真实实现。
var validateResolvedIP = urlvalidator.ValidateResolvedIP
// GetClient 返回共享的 HTTP 客户端实例 // GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport // 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 // 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
...@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) { ...@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) {
var rt http.RoundTripper = transport var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts { if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport} rt = newValidatedTransport(transport)
} }
return &http.Client{ return &http.Client{
Transport: rt, Transport: rt,
...@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string { ...@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string {
} }
type validatedTransport struct { type validatedTransport struct {
base http.RoundTripper base http.RoundTripper
validatedHosts sync.Map // map[string]time.Time, value 为过期时间
now func() time.Time
}
func newValidatedTransport(base http.RoundTripper) *validatedTransport {
return &validatedTransport{
base: base,
now: time.Now,
}
}
func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
if t == nil {
return false
}
raw, ok := t.validatedHosts.Load(host)
if !ok {
return false
}
expireAt, ok := raw.(time.Time)
if !ok {
t.validatedHosts.Delete(host)
return false
}
if now.Before(expireAt) {
return true
}
t.validatedHosts.Delete(host)
return false
} }
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil { if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname()) host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
if host != "" { if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil { now := time.Now()
return nil, err if t != nil && t.now != nil {
now = t.now()
}
if !t.isValidatedHost(host, now) {
if err := validateResolvedIP(host); err != nil {
return nil, err
}
t.validatedHosts.Store(host, now.Add(validatedHostTTL))
} }
} }
} }
if t == nil || t.base == nil {
return nil, fmt.Errorf("validated transport base is nil")
}
return t.base.RoundTrip(req) return t.base.RoundTrip(req)
} }
package httpclient
import (
"errors"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(host string) error {
atomic.AddInt32(&validateCalls, 1)
require.Equal(t, "api.openai.com", host)
return nil
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730000000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
}
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(_ string) error {
atomic.AddInt32(&validateCalls, 1)
return nil
}
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730001000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
now = now.Add(validatedHostTTL + time.Second)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
}
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
expectedErr := errors.New("dns rebinding rejected")
validateResolvedIP = func(_ string) error {
return expectedErr
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
})
transport := newValidatedTransport(base)
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.ErrorIs(t, err, expectedErr)
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
}
package httputil
import (
"bytes"
"io"
"net/http"
)
const (
requestBodyReadInitCap = 512
requestBodyReadMaxInitCap = 1 << 20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
}
capHint := requestBodyReadInitCap
if req.ContentLength > 0 {
switch {
case req.ContentLength < int64(requestBodyReadInitCap):
capHint = requestBodyReadInitCap
case req.ContentLength > int64(requestBodyReadMaxInitCap):
capHint = requestBodyReadMaxInitCap
default:
capHint = int(req.ContentLength)
}
}
buf := bytes.NewBuffer(make([]byte, 0, capHint))
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
...@@ -67,6 +67,14 @@ func normalizeIP(ip string) string { ...@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 // privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var privateNets []*net.IPNet var privateNets []*net.IPNet
// CompiledIPRules 表示预编译的 IP 匹配规则。
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
type CompiledIPRules struct {
CIDRs []*net.IPNet
IPs []net.IP
PatternCount int
}
func init() { func init() {
for _, cidr := range []string{ for _, cidr := range []string{
"10.0.0.0/8", "10.0.0.0/8",
...@@ -84,6 +92,53 @@ func init() { ...@@ -84,6 +92,53 @@ func init() {
} }
} }
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
func CompileIPRules(patterns []string) *CompiledIPRules {
compiled := &CompiledIPRules{
CIDRs: make([]*net.IPNet, 0, len(patterns)),
IPs: make([]net.IP, 0, len(patterns)),
PatternCount: len(patterns),
}
for _, pattern := range patterns {
normalized := strings.TrimSpace(pattern)
if normalized == "" {
continue
}
if strings.Contains(normalized, "/") {
_, cidr, err := net.ParseCIDR(normalized)
if err != nil || cidr == nil {
continue
}
compiled.CIDRs = append(compiled.CIDRs, cidr)
continue
}
parsedIP := net.ParseIP(normalized)
if parsedIP == nil {
continue
}
compiled.IPs = append(compiled.IPs, parsedIP)
}
return compiled
}
func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
if parsedIP == nil || rules == nil {
return false
}
for _, cidr := range rules.CIDRs {
if cidr.Contains(parsedIP) {
return true
}
}
for _, ruleIP := range rules.IPs {
if parsedIP.Equal(ruleIP) {
return true
}
}
return false
}
// isPrivateIP 检查 IP 是否为私有地址。 // isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool { func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
...@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool { ...@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
// 2. 如果白名单不为空,IP 必须在白名单中 // 2. 如果白名单不为空,IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝) // 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) { func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
return CheckIPRestrictionWithCompiledRules(
clientIP,
CompileIPRules(whitelist),
CompileIPRules(blacklist),
)
}
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
// 规范化 IP // 规范化 IP
clientIP = normalizeIP(clientIP) clientIP = normalizeIP(clientIP)
if clientIP == "" { if clientIP == "" {
return false, "access denied" return false, "access denied"
} }
parsedIP := net.ParseIP(clientIP)
if parsedIP == nil {
return false, "access denied"
}
// 1. 检查黑名单 // 1. 检查黑名单
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) { if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
return false, "access denied" return false, "access denied"
} }
// 2. 检查白名单(如果设置了白名单,IP 必须在其中) // 2. 检查白名单(如果设置了白名单,IP 必须在其中)
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) { if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
return false, "access denied" return false, "access denied"
} }
......
...@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { ...@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
require.Equal(t, 200, w.Code) require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String()) require.Equal(t, "9.9.9.9", w.Body.String())
} }
func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
blacklist := CompileIPRules([]string{"10.1.1.1"})
allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
require.True(t, allowed)
require.Equal(t, "", reason)
allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}
func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"go.uber.org/zap" "go.uber.org/zap"
...@@ -42,15 +43,19 @@ type LogEvent struct { ...@@ -42,15 +43,19 @@ type LogEvent struct {
var ( var (
mu sync.RWMutex mu sync.RWMutex
global *zap.Logger global atomic.Pointer[zap.Logger]
sugar *zap.SugaredLogger sugar atomic.Pointer[zap.SugaredLogger]
atomicLevel zap.AtomicLevel atomicLevel zap.AtomicLevel
initOptions InitOptions initOptions InitOptions
currentSink Sink currentSink atomic.Value // sinkState
stdLogUndo func() stdLogUndo func()
bootstrapOnce sync.Once bootstrapOnce sync.Once
) )
type sinkState struct {
sink Sink
}
func InitBootstrap() { func InitBootstrap() {
bootstrapOnce.Do(func() { bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil { if err := Init(bootstrapOptions()); err != nil {
...@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error { ...@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
return err return err
} }
prev := global prev := global.Load()
global = zl global.Store(zl)
sugar = zl.Sugar() sugar.Store(zl.Sugar())
atomicLevel = al atomicLevel = al
initOptions = normalized initOptions = normalized
...@@ -115,24 +120,32 @@ func SetLevel(level string) error { ...@@ -115,24 +120,32 @@ func SetLevel(level string) error {
func CurrentLevel() string { func CurrentLevel() string {
mu.RLock() mu.RLock()
defer mu.RUnlock() defer mu.RUnlock()
if global == nil { if global.Load() == nil {
return "info" return "info"
} }
return atomicLevel.Level().String() return atomicLevel.Level().String()
} }
func SetSink(sink Sink) { func SetSink(sink Sink) {
mu.Lock() currentSink.Store(sinkState{sink: sink})
defer mu.Unlock() }
currentSink = sink
func loadSink() Sink {
v := currentSink.Load()
if v == nil {
return nil
}
state, ok := v.(sinkState)
if !ok {
return nil
}
return state.sink
} }
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。 // WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。 // 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) { func WriteSinkEvent(level, component, message string, fields map[string]any) {
mu.RLock() sink := loadSink()
sink := currentSink
mu.RUnlock()
if sink == nil { if sink == nil {
return return
} }
...@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) { ...@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
} }
func L() *zap.Logger { func L() *zap.Logger {
mu.RLock() if l := global.Load(); l != nil {
defer mu.RUnlock() return l
if global != nil {
return global
} }
return zap.NewNop() return zap.NewNop()
} }
func S() *zap.SugaredLogger { func S() *zap.SugaredLogger {
mu.RLock() if s := sugar.Load(); s != nil {
defer mu.RUnlock() return s
if sugar != nil {
return sugar
} }
return zap.NewNop().Sugar() return zap.NewNop().Sugar()
} }
...@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger { ...@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
} }
func Sync() { func Sync() {
mu.RLock() l := global.Load()
l := global
mu.RUnlock()
if l != nil { if l != nil {
_ = l.Sync() _ = l.Sync()
} }
...@@ -210,7 +217,11 @@ func bridgeStdLogLocked() { ...@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
log.SetFlags(0) log.SetFlags(0)
log.SetPrefix("") log.SetPrefix("")
log.SetOutput(newStdLogBridge(global.Named("stdlog"))) base := global.Load()
if base == nil {
base = zap.NewNop()
}
log.SetOutput(newStdLogBridge(base.Named("stdlog")))
stdLogUndo = func() { stdLogUndo = func() {
log.SetOutput(prevWriter) log.SetOutput(prevWriter)
...@@ -220,7 +231,11 @@ func bridgeStdLogLocked() { ...@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
} }
func bridgeSlogLocked() { func bridgeSlogLocked() {
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog")))) base := global.Load()
if base == nil {
base = zap.NewNop()
}
slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
} }
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
...@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore ...@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
// Only handle sink forwarding — the inner cores write via their own // Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above). // Write methods (added to CheckedEntry by s.core.Check above).
mu.RLock() sink := loadSink()
sink := currentSink
mu.RUnlock()
if sink == nil { if sink == nil {
return nil return nil
} }
...@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level { ...@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") { if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError return LevelError
} }
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn return LevelWarn
} }
return LevelInfo return LevelInfo
...@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) { ...@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
return return
} }
mu.RLock() initialized := global.Load() != nil
initialized := global != nil
mu.RUnlock()
if !initialized { if !initialized {
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。 // 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
log.Print(msg) log.Print(msg)
......
...@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error { ...@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
return true return true
}) })
entry := h.logger.With(fields...)
switch { switch {
case record.Level >= slog.LevelError: case record.Level >= slog.LevelError:
entry.Error(record.Message) h.logger.Error(record.Message, fields...)
case record.Level >= slog.LevelWarn: case record.Level >= slog.LevelWarn:
entry.Warn(record.Message) h.logger.Warn(record.Message, fields...)
case record.Level <= slog.LevelDebug: case record.Level <= slog.LevelDebug:
entry.Debug(record.Message) h.logger.Debug(record.Message, fields...)
default: default:
entry.Info(record.Message) h.logger.Info(record.Message, fields...)
} }
return nil return nil
} }
......
...@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) { ...@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
{msg: "Warning: queue full", want: LevelWarn}, {msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError}, {msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError}, {msg: "[ERROR] upstream unavailable", want: LevelError},
{msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
{msg: "service started", want: LevelInfo}, {msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug}, {msg: "debug: cache miss", want: LevelDebug},
} }
......
...@@ -36,10 +36,18 @@ const ( ...@@ -36,10 +36,18 @@ const (
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
) )
const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora = "sora"
)
// OAuthSession stores OAuth flow state for OpenAI // OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct { type OAuthSession struct {
State string `json:"state"` State string `json:"state"`
CodeVerifier string `json:"code_verifier"` CodeVerifier string `json:"code_verifier"`
ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"` ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
...@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string { ...@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL // BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
}
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
if redirectURI == "" { if redirectURI == "" {
redirectURI = DefaultRedirectURI redirectURI = DefaultRedirectURI
} }
clientID, codexFlow := OAuthClientConfigByPlatform(platform)
params := url.Values{} params := url.Values{}
params.Set("response_type", "code") params.Set("response_type", "code")
params.Set("client_id", ClientID) params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI) params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes) params.Set("scope", DefaultScopes)
params.Set("state", state) params.Set("state", state)
...@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { ...@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
params.Set("code_challenge_method", "S256") params.Set("code_challenge_method", "S256")
// OpenAI specific parameters // OpenAI specific parameters
params.Set("id_token_add_organizations", "true") params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true") if codexFlow {
params.Set("codex_cli_simplified_flow", "true")
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
} }
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
switch strings.ToLower(strings.TrimSpace(platform)) {
case OAuthPlatformSora:
return ClientID, false
default:
return ClientID, true
}
}
// TokenRequest represents the token exchange request body // TokenRequest represents the token exchange request body
type TokenRequest struct { type TokenRequest struct {
GrantType string `json:"grant_type"` GrantType string `json:"grant_type"`
...@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string { ...@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
return params.Encode() return params.Encode()
} }
// ParseIDToken parses the ID Token JWT and extracts claims // ParseIDToken parses the ID Token JWT and extracts claims.
// Note: This does NOT verify the signature - it only decodes the payload // 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
// For production, you should verify the token signature using OpenAI's public keys // 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) { func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".") parts := strings.Split(idToken, ".")
if len(parts) != 3 { if len(parts) != 3 {
...@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { ...@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err) return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
} }
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
const clockSkewTolerance = 120 // 秒
now := time.Now().Unix()
if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
}
return &claims, nil return &claims, nil
} }
......
package openai package openai
import ( import (
"net/url"
"sync" "sync"
"testing" "testing"
"time" "time"
...@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) { ...@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
t.Fatal("stopCh 未关闭") t.Fatal("stopCh 未关闭")
} }
} }
func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "true" {
t.Fatalf("codex flow mismatch: got=%q want=true", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "" {
t.Fatalf("codex flow should be empty for sora, got=%q", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
...@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P ...@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
t.Helper() t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型 // 先用 raw json 解析,因为 Data 是 any 类型
var raw struct { var raw struct {
Code int `json:"code"` Code int `json:"code"`
Message string `json:"message"` Message string `json:"message"`
Reason string `json:"reason,omitempty"` Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"` Data json.RawMessage `json:"data,omitempty"`
} }
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
......
...@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st ...@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
"cipher_suites", len(spec.CipherSuites), "cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions), "extensions", len(spec.Extensions),
"compression_methods", spec.CompressionMethods, "compression_methods", spec.CompressionMethods,
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax), "tls_vers_max", spec.TLSVersMax,
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin)) "tls_vers_min", spec.TLSVersMin)
if d.profile != nil { if d.profile != nil {
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
...@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st ...@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
state := tlsConn.ConnectionState() state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_socks5_handshake_success", slog.Debug("tls_fingerprint_socks5_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version), "version", state.Version,
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), "cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol) "alpn", state.NegotiatedProtocol)
return tlsConn, nil return tlsConn, nil
...@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri ...@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
state := tlsConn.ConnectionState() state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_http_proxy_handshake_success", slog.Debug("tls_fingerprint_http_proxy_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version), "version", state.Version,
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), "cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol) "alpn", state.NegotiatedProtocol)
return tlsConn, nil return tlsConn, nil
...@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net. ...@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
// Log successful handshake details // Log successful handshake details
state := tlsConn.ConnectionState() state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_handshake_success", slog.Debug("tls_fingerprint_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version), "version", state.Version,
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), "cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol) "alpn", state.NegotiatedProtocol)
return tlsConn, nil return tlsConn, nil
......
...@@ -139,6 +139,7 @@ type UsageLogFilters struct { ...@@ -139,6 +139,7 @@ type UsageLogFilters struct {
AccountID int64 AccountID int64
GroupID int64 GroupID int64
Model string Model string
RequestType *int16
Stream *bool Stream *bool
BillingType *int8 BillingType *int8
StartTime *time.Time StartTime *time.Time
......
...@@ -50,11 +50,6 @@ type accountRepository struct { ...@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache service.SchedulerCache schedulerCache service.SchedulerCache
} }
type tempUnschedSnapshot struct {
until *time.Time
reason string
}
// NewAccountRepository 创建账户仓储实例。 // NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
...@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi ...@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID) accountIDs = append(accountIDs, acc.ID)
} }
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi ...@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok { if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags out.AccountGroups = ags
} }
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outByID[entAcc.ID] = out outByID[entAcc.ID] = out
} }
...@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac ...@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
} }
} }
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
return
}
uniqueIDs := make([]int64, 0, len(accountIDs))
seen := make(map[int64]struct{}, len(accountIDs))
for _, id := range accountIDs {
if id <= 0 {
continue
}
if _, exists := seen[id]; exists {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return
}
accounts, err := r.GetByIDs(ctx, uniqueIDs)
if err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err)
return
}
for _, account := range accounts {
if account == nil {
continue
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err)
}
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error { func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update(). _, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)). Where(dbaccount.IDEQ(id)).
...@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync = true shouldSync = true
} }
if shouldSync { if shouldSync {
for _, id := range ids { r.syncSchedulerAccountSnapshots(ctx, ids)
r.syncSchedulerAccountSnapshot(ctx, id)
}
} }
} }
return rows, nil return rows, nil
...@@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d ...@@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil { if err != nil {
return nil, err return nil, err
} }
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d ...@@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok { if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags out.AccountGroups = ags
} }
if snap, ok := tempUnschedMap[acc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outAccounts = append(outAccounts, *out) outAccounts = append(outAccounts, *out)
} }
...@@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account { ...@@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
) )
} }
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
return out, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`, pq.Array(accountIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
var until sql.NullTime
var reason sql.NullString
if err := rows.Scan(&id, &until, &reason); err != nil {
return nil, err
}
var untilPtr *time.Time
if until.Valid {
tmp := until.Time
untilPtr = &tmp
}
if reason.Valid {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
} else {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) { func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy) proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 { if len(proxyIDs) == 0 {
...@@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account { ...@@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
rateMultiplier := m.RateMultiplier rateMultiplier := m.RateMultiplier
return &service.Account{ return &service.Account{
ID: m.ID, ID: m.ID,
Name: m.Name, Name: m.Name,
Notes: m.Notes, Notes: m.Notes,
Platform: m.Platform, Platform: m.Platform,
Type: m.Type, Type: m.Type,
Credentials: copyJSONMap(m.Credentials), Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra), Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID, ProxyID: m.ProxyID,
Concurrency: m.Concurrency, Concurrency: m.Concurrency,
Priority: m.Priority, Priority: m.Priority,
RateMultiplier: &rateMultiplier, RateMultiplier: &rateMultiplier,
Status: m.Status, Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage), ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt, LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt, ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired, AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt, UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable, Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt, RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt, RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil, OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart, TempUnschedulableUntil: m.TempUnschedulableUntil,
SessionWindowEnd: m.SessionWindowEnd, TempUnschedulableReason: derefString(m.TempUnschedulableReason),
SessionWindowStatus: derefString(m.SessionWindowStatus), SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
} }
} }
......
...@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() { ...@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s.Require().Nil(got.OverloadUntil) s.Require().Nil(got.OverloadUntil)
} }
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
reason := `{"rule":"429","matched_keyword":"too many requests"}`
s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().NotNil(gotByID.TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByID.TempUnschedulableReason)
gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
s.Require().NoError(err)
s.Require().Len(gotByIDs, 2)
s.Require().Equal(acc2.ID, gotByIDs[0].ID)
s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
s.Require().Equal(acc1.ID, gotByIDs[1].ID)
s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().Nil(cleared.TempUnschedulableUntil)
s.Require().Equal("", cleared.TempUnschedulableReason)
}
// --- UpdateLastUsed --- // --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() { func (s *AccountRepoSuite) TestUpdateLastUsed() {
......
...@@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User { ...@@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User {
return nil return nil
} }
return &service.User{ return &service.User{
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Username: u.Username, Username: u.Username,
Notes: u.Notes, Notes: u.Notes,
PasswordHash: u.PasswordHash, PasswordHash: u.PasswordHash,
Role: u.Role, Role: u.Role,
Balance: u.Balance, Balance: u.Balance,
Concurrency: u.Concurrency, Concurrency: u.Concurrency,
Status: u.Status, Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted, SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
TotpEnabled: u.TotpEnabled, SoraStorageUsedBytes: u.SoraStorageUsedBytes,
TotpEnabledAt: u.TotpEnabledAt, TotpSecretEncrypted: u.TotpSecretEncrypted,
CreatedAt: u.CreatedAt, TotpEnabled: u.TotpEnabled,
UpdatedAt: u.UpdatedAt, TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
} }
} }
...@@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
SoraImagePrice540: g.SoraImagePrice540, SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
......
...@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID ...@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
return result, nil return result, nil
} }
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
now, err := c.rdb.Time(ctx).Result()
if err != nil {
return nil, fmt.Errorf("redis TIME: %w", err)
}
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
pipe := c.rdb.Pipeline()
type accountCmd struct {
accountID int64
zcardCmd *redis.IntCmd
}
cmds := make([]accountCmd, 0, len(accountIDs))
for _, accountID := range accountIDs {
slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
cmds = append(cmds, accountCmd{
accountID: accountID,
zcardCmd: pipe.ZCard(ctx, slotKey),
})
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("pipeline exec: %w", err)
}
result := make(map[int64]int, len(accountIDs))
for _, cmd := range cmds {
result[cmd.accountID] = int(cmd.zcardCmd.Val())
}
return result, nil
}
// User slot operations // User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
......
...@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { ...@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
} }
func TestGatewayCacheSuite(t *testing.T) { func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite)) suite.Run(t, new(GatewayCacheSuite))
} }
...@@ -4,6 +4,8 @@ import ( ...@@ -4,6 +4,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
...@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject) SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
// 设置模型路由配置 // 设置模型路由配置
if groupIn.ModelRouting != nil { if groupIn.ModelRouting != nil {
...@@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject) SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
// 处理 FallbackGroupID:nil 时清除,否则设置 // 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil { if groupIn.FallbackGroupID != nil {
...@@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, ...@@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx) return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
} }
// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
// 返回结构:map[groupID]exists。
func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
result := make(map[int64]bool, len(ids))
if len(ids) == 0 {
return result, nil
}
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
result[id] = false
}
if len(uniqueIDs) == 0 {
return result, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id
FROM groups
WHERE id = ANY($1) AND deleted_at IS NULL
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
result[id] = true
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
...@@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic ...@@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
return nil return nil
} }
// 使用事务批量更新 // 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
tx, err := r.client.Tx(ctx) sortOrderByID := make(map[int64]int, len(updates))
if err != nil { groupIDs := make([]int64, 0, len(updates))
for _, u := range updates {
if u.ID <= 0 {
continue
}
if _, exists := sortOrderByID[u.ID]; !exists {
groupIDs = append(groupIDs, u.ID)
}
sortOrderByID[u.ID] = u.SortOrder
}
if len(groupIDs) == 0 {
return nil
}
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。
var existingCount int
if err := scanSingleRow(
ctx,
r.sql,
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
[]any{pq.Array(groupIDs)},
&existingCount,
); err != nil {
return err return err
} }
defer func() { _ = tx.Rollback() }() if existingCount != len(groupIDs) {
return service.ErrGroupNotFound
}
for _, u := range updates { args := make([]any, 0, len(groupIDs)*2+1)
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil { caseClauses := make([]string, 0, len(groupIDs))
return translatePersistenceError(err, service.ErrGroupNotFound, nil) placeholder := 1
} for _, id := range groupIDs {
caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
args = append(args, id, sortOrderByID[id])
placeholder += 2
} }
args = append(args, pq.Array(groupIDs))
query := fmt.Sprintf(`
UPDATE groups
SET sort_order = CASE id
%s
ELSE sort_order
END
WHERE deleted_at IS NULL AND id = ANY($%d)
`, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
if err := tx.Commit(); err != nil { result, err := r.sql.ExecContext(ctx, query, args...)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err return err
} }
if affected != int64(len(groupIDs)) {
return service.ErrGroupNotFound
}
for _, id := range groupIDs {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
}
}
return nil return nil
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment