Unverified Commit 6bccb8a8 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
...@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string { ...@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
return normalizeIP(c.ClientIP()) return normalizeIP(c.ClientIP())
} }
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
// 适用于 ACL / 风控等安全敏感场景。
func GetTrustedClientIP(c *gin.Context) string {
if c == nil {
return ""
}
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。 // normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string { func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip) ip = strings.TrimSpace(ip)
...@@ -54,29 +64,34 @@ func normalizeIP(ip string) string { ...@@ -54,29 +64,34 @@ func normalizeIP(ip string) string {
return ip return ip
} }
// isPrivateIP 检查 IP 是否为私有地址。 // privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
func isPrivateIP(ipStr string) bool { var privateNets []*net.IPNet
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// 私有 IP 范围 func init() {
privateBlocks := []string{ for _, cidr := range []string{
"10.0.0.0/8", "10.0.0.0/8",
"172.16.0.0/12", "172.16.0.0/12",
"192.168.0.0/16", "192.168.0.0/16",
"127.0.0.0/8", "127.0.0.0/8",
"::1/128", "::1/128",
"fc00::/7", "fc00::/7",
} } {
_, block, err := net.ParseCIDR(cidr)
for _, block := range privateBlocks {
_, cidr, err := net.ParseCIDR(block)
if err != nil { if err != nil {
continue panic("invalid CIDR: " + cidr)
} }
if cidr.Contains(ip) { privateNets = append(privateNets, block)
}
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
for _, block := range privateNets {
if block.Contains(ip) {
return true return true
} }
} }
......
//go:build unit
package ip
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
ip string
expected bool
}{
// 私有 IPv4
{"10.x 私有地址", "10.0.0.1", true},
{"10.x 私有地址段末", "10.255.255.255", true},
{"172.16.x 私有地址", "172.16.0.1", true},
{"172.31.x 私有地址", "172.31.255.255", true},
{"192.168.x 私有地址", "192.168.1.1", true},
{"127.0.0.1 本地回环", "127.0.0.1", true},
{"127.x 回环段", "127.255.255.255", true},
// 公网 IPv4
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
{"1.1.1.1 公网", "1.1.1.1", false},
{"172.15.255.255 非私有", "172.15.255.255", false},
{"172.32.0.0 非私有", "172.32.0.0", false},
{"11.0.0.1 公网", "11.0.0.1", false},
// IPv6
{"::1 IPv6 回环", "::1", true},
{"fc00:: IPv6 私有", "fc00::1", true},
{"fd00:: IPv6 私有", "fd00::1", true},
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
// 无效输入
{"空字符串", "", false},
{"非法字符串", "not-an-ip", false},
{"不完整 IP", "192.168", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isPrivateIP(tc.ip)
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
})
}
}
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
require.NoError(t, r.SetTrustedProxies(nil))
r.GET("/t", func(c *gin.Context) {
c.String(200, GetTrustedClientIP(c))
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
r.ServeHTTP(w, req)
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}
package logger
import "github.com/Wei-Shaw/sub2api/internal/config"
func OptionsFromConfig(cfg config.LogConfig) InitOptions {
return InitOptions{
Level: cfg.Level,
Format: cfg.Format,
ServiceName: cfg.ServiceName,
Environment: cfg.Environment,
Caller: cfg.Caller,
StacktraceLevel: cfg.StacktraceLevel,
Output: OutputOptions{
ToStdout: cfg.Output.ToStdout,
ToFile: cfg.Output.ToFile,
FilePath: cfg.Output.FilePath,
},
Rotation: RotationOptions{
MaxSizeMB: cfg.Rotation.MaxSizeMB,
MaxBackups: cfg.Rotation.MaxBackups,
MaxAgeDays: cfg.Rotation.MaxAgeDays,
Compress: cfg.Rotation.Compress,
LocalTime: cfg.Rotation.LocalTime,
},
Sampling: SamplingOptions{
Enabled: cfg.Sampling.Enabled,
Initial: cfg.Sampling.Initial,
Thereafter: cfg.Sampling.Thereafter,
},
}
}
package logger
import (
"context"
"fmt"
"io"
"log"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
type Level = zapcore.Level
const (
LevelDebug = zapcore.DebugLevel
LevelInfo = zapcore.InfoLevel
LevelWarn = zapcore.WarnLevel
LevelError = zapcore.ErrorLevel
LevelFatal = zapcore.FatalLevel
)
type Sink interface {
WriteLogEvent(event *LogEvent)
}
type LogEvent struct {
Time time.Time
Level string
Component string
Message string
LoggerName string
Fields map[string]any
}
var (
mu sync.RWMutex
global *zap.Logger
sugar *zap.SugaredLogger
atomicLevel zap.AtomicLevel
initOptions InitOptions
currentSink Sink
stdLogUndo func()
bootstrapOnce sync.Once
)
func InitBootstrap() {
bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err)
}
})
}
func Init(options InitOptions) error {
mu.Lock()
defer mu.Unlock()
return initLocked(options)
}
func initLocked(options InitOptions) error {
normalized := options.normalized()
zl, al, err := buildLogger(normalized)
if err != nil {
return err
}
prev := global
global = zl
sugar = zl.Sugar()
atomicLevel = al
initOptions = normalized
bridgeSlogLocked()
bridgeStdLogLocked()
if prev != nil {
_ = prev.Sync()
}
return nil
}
func Reconfigure(mutator func(*InitOptions) error) error {
mu.Lock()
defer mu.Unlock()
next := initOptions
if mutator != nil {
if err := mutator(&next); err != nil {
return err
}
}
return initLocked(next)
}
func SetLevel(level string) error {
lv, ok := parseLevel(level)
if !ok {
return fmt.Errorf("invalid log level: %s", level)
}
mu.Lock()
defer mu.Unlock()
atomicLevel.SetLevel(lv)
initOptions.Level = strings.ToLower(strings.TrimSpace(level))
return nil
}
func CurrentLevel() string {
mu.RLock()
defer mu.RUnlock()
if global == nil {
return "info"
}
return atomicLevel.Level().String()
}
func SetSink(sink Sink) {
mu.Lock()
defer mu.Unlock()
currentSink = sink
}
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) {
mu.RLock()
sink := currentSink
mu.RUnlock()
if sink == nil {
return
}
level = strings.ToLower(strings.TrimSpace(level))
if level == "" {
level = "info"
}
component = strings.TrimSpace(component)
message = strings.TrimSpace(message)
if message == "" {
return
}
eventFields := make(map[string]any, len(fields)+1)
for k, v := range fields {
eventFields[k] = v
}
if component != "" {
if _, ok := eventFields["component"]; !ok {
eventFields["component"] = component
}
}
sink.WriteLogEvent(&LogEvent{
Time: time.Now(),
Level: level,
Component: component,
Message: message,
LoggerName: component,
Fields: eventFields,
})
}
func L() *zap.Logger {
mu.RLock()
defer mu.RUnlock()
if global != nil {
return global
}
return zap.NewNop()
}
func S() *zap.SugaredLogger {
mu.RLock()
defer mu.RUnlock()
if sugar != nil {
return sugar
}
return zap.NewNop().Sugar()
}
func With(fields ...zap.Field) *zap.Logger {
return L().With(fields...)
}
func Sync() {
mu.RLock()
l := global
mu.RUnlock()
if l != nil {
_ = l.Sync()
}
}
func bridgeStdLogLocked() {
if stdLogUndo != nil {
stdLogUndo()
stdLogUndo = nil
}
prevFlags := log.Flags()
prevPrefix := log.Prefix()
prevWriter := log.Writer()
log.SetFlags(0)
log.SetPrefix("")
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
stdLogUndo = func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}
}
func bridgeSlogLocked() {
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
}
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
level, _ := parseLevel(options.Level)
atomic := zap.NewAtomicLevelAt(level)
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.CapitalLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.MillisDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
var enc zapcore.Encoder
if options.Format == "console" {
enc = zapcore.NewConsoleEncoder(encoderCfg)
} else {
enc = zapcore.NewJSONEncoder(encoderCfg)
}
sinkCore := newSinkCore()
cores := make([]zapcore.Core, 0, 3)
if options.Output.ToStdout {
infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl < zapcore.WarnLevel
})
errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel
})
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority))
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority))
}
if options.Output.ToFile {
fileCore, filePath, fileErr := buildFileCore(enc, atomic, options)
if fileErr != nil {
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n",
time.Now().Format(time.RFC3339Nano),
filePath,
fileErr,
)
} else {
cores = append(cores, fileCore)
}
}
if len(cores) == 0 {
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic))
}
core := zapcore.NewTee(cores...)
if options.Sampling.Enabled {
core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter)
}
core = sinkCore.Wrap(core)
stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel)
zapOpts := make([]zap.Option, 0, 5)
if options.Caller {
zapOpts = append(zapOpts, zap.AddCaller())
}
if stacktraceLevel <= zapcore.FatalLevel {
zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel))
}
logger := zap.New(core, zapOpts...).With(
zap.String("service", options.ServiceName),
zap.String("env", options.Environment),
)
return logger, atomic, nil
}
func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) {
filePath := options.Output.FilePath
if strings.TrimSpace(filePath) == "" {
filePath = resolveLogFilePath("")
}
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, filePath, err
}
lj := &lumberjack.Logger{
Filename: filePath,
MaxSize: options.Rotation.MaxSizeMB,
MaxBackups: options.Rotation.MaxBackups,
MaxAge: options.Rotation.MaxAgeDays,
Compress: options.Rotation.Compress,
LocalTime: options.Rotation.LocalTime,
}
return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil
}
type sinkCore struct {
core zapcore.Core
fields []zapcore.Field
}
func newSinkCore() *sinkCore {
return &sinkCore{}
}
func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core {
cp := *s
cp.core = core
return &cp
}
func (s *sinkCore) Enabled(level zapcore.Level) bool {
return s.core.Enabled(level)
}
func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := append([]zapcore.Field{}, s.fields...)
nextFields = append(nextFields, fields...)
return &sinkCore{
core: s.core.With(fields),
fields: nextFields,
}
}
func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
// Delegate to inner core (tee) so each sub-core's level enabler is respected.
// Then add ourselves for sink forwarding only.
ce = s.core.Check(entry, ce)
if ce != nil {
ce = ce.AddCore(entry, s)
}
return ce
}
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
// Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above).
mu.RLock()
sink := currentSink
mu.RUnlock()
if sink == nil {
return nil
}
enc := zapcore.NewMapObjectEncoder()
for _, f := range s.fields {
f.AddTo(enc)
}
for _, f := range fields {
f.AddTo(enc)
}
event := &LogEvent{
Time: entry.Time,
Level: strings.ToLower(entry.Level.String()),
Component: entry.LoggerName,
Message: entry.Message,
LoggerName: entry.LoggerName,
Fields: enc.Fields,
}
sink.WriteLogEvent(event)
return nil
}
func (s *sinkCore) Sync() error {
return s.core.Sync()
}
type stdLogBridge struct {
logger *zap.Logger
}
func newStdLogBridge(l *zap.Logger) io.Writer {
if l == nil {
l = zap.NewNop()
}
return &stdLogBridge{logger: l}
}
func (b *stdLogBridge) Write(p []byte) (int, error) {
msg := normalizeStdLogMessage(string(p))
if msg == "" {
return len(p), nil
}
level := inferStdLogLevel(msg)
entry := b.logger.WithOptions(zap.AddCallerSkip(4))
switch level {
case LevelDebug:
entry.Debug(msg, zap.Bool("legacy_stdlog", true))
case LevelWarn:
entry.Warn(msg, zap.Bool("legacy_stdlog", true))
case LevelError, LevelFatal:
entry.Error(msg, zap.Bool("legacy_stdlog", true))
default:
entry.Info(msg, zap.Bool("legacy_stdlog", true))
}
return len(p), nil
}
func normalizeStdLogMessage(raw string) string {
msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " "))
if msg == "" {
return ""
}
return strings.Join(strings.Fields(msg), " ")
}
func inferStdLogLevel(msg string) Level {
lower := strings.ToLower(strings.TrimSpace(msg))
if lower == "" {
return LevelInfo
}
if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") {
return LevelDebug
}
if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") {
return LevelWarn
}
if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") {
return LevelError
}
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError
}
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn
}
return LevelInfo
}
// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。
func LegacyPrintf(component, format string, args ...any) {
msg := normalizeStdLogMessage(fmt.Sprintf(format, args...))
if msg == "" {
return
}
mu.RLock()
initialized := global != nil
mu.RUnlock()
if !initialized {
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
log.Print(msg)
return
}
l := L()
if component != "" {
l = l.With(zap.String("component", component))
}
l = l.WithOptions(zap.AddCallerSkip(1))
switch inferStdLogLevel(msg) {
case LevelDebug:
l.Debug(msg, zap.Bool("legacy_printf", true))
case LevelWarn:
l.Warn(msg, zap.Bool("legacy_printf", true))
case LevelError, LevelFatal:
l.Error(msg, zap.Bool("legacy_printf", true))
default:
l.Info(msg, zap.Bool("legacy_printf", true))
}
}
type contextKey string
const loggerContextKey contextKey = "ctx_logger"
func IntoContext(ctx context.Context, l *zap.Logger) context.Context {
if ctx == nil {
ctx = context.Background()
}
if l == nil {
l = L()
}
return context.WithValue(ctx, loggerContextKey, l)
}
func FromContext(ctx context.Context) *zap.Logger {
if ctx == nil {
return L()
}
if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil {
return l
}
return L()
}
package logger
import (
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestInit_DualOutput(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stderrR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: logPath,
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 2,
MaxAgeDays: 1,
},
Sampling: SamplingOptions{Enabled: false},
})
if err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("dual-output-info")
L().Warn("dual-output-warn")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "dual-output-info") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "dual-output-warn") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
fileBytes, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("read log file: %v", err)
}
fileText := string(fileBytes)
if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") {
t.Fatalf("file missing logs: %s", fileText)
}
}
func TestInit_FileOutputFailureDowngrade(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
_, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "info",
Format: "json",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"),
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 1,
MaxAgeDays: 1,
},
})
if err != nil {
t.Fatalf("Init() should downgrade instead of failing, got: %v", err)
}
_ = stderrW.Close()
stderrBytes, _ := io.ReadAll(stderrR)
if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") {
t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes))
}
}
func TestInit_CallerShouldPointToCallsite(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
_, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Caller: true,
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("caller-check")
Sync()
_ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR)
var line string
for _, item := range strings.Split(string(logBytes), "\n") {
if strings.Contains(item, "caller-check") {
line = item
break
}
}
if line == "" {
t.Fatalf("log output missing caller-check: %s", string(logBytes))
}
var payload map[string]any
if err := json.Unmarshal([]byte(line), &payload); err != nil {
t.Fatalf("parse log json failed: %v, line=%s", err, line)
}
caller, _ := payload["caller"].(string)
if !strings.Contains(caller, "logger_test.go:") {
t.Fatalf("caller should point to this test file, got: %s", caller)
}
}
package logger
import (
"os"
"path/filepath"
"strings"
"time"
)
const (
// DefaultContainerLogPath 为容器内默认日志文件路径。
DefaultContainerLogPath = "/app/data/logs/sub2api.log"
defaultLogFilename = "sub2api.log"
)
type InitOptions struct {
Level string
Format string
ServiceName string
Environment string
Caller bool
StacktraceLevel string
Output OutputOptions
Rotation RotationOptions
Sampling SamplingOptions
}
type OutputOptions struct {
ToStdout bool
ToFile bool
FilePath string
}
type RotationOptions struct {
MaxSizeMB int
MaxBackups int
MaxAgeDays int
Compress bool
LocalTime bool
}
type SamplingOptions struct {
Enabled bool
Initial int
Thereafter int
}
func (o InitOptions) normalized() InitOptions {
out := o
out.Level = strings.ToLower(strings.TrimSpace(out.Level))
if out.Level == "" {
out.Level = "info"
}
out.Format = strings.ToLower(strings.TrimSpace(out.Format))
if out.Format == "" {
out.Format = "console"
}
out.ServiceName = strings.TrimSpace(out.ServiceName)
if out.ServiceName == "" {
out.ServiceName = "sub2api"
}
out.Environment = strings.TrimSpace(out.Environment)
if out.Environment == "" {
out.Environment = "production"
}
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel))
if out.StacktraceLevel == "" {
out.StacktraceLevel = "error"
}
if !out.Output.ToStdout && !out.Output.ToFile {
out.Output.ToStdout = true
}
out.Output.FilePath = resolveLogFilePath(out.Output.FilePath)
if out.Rotation.MaxSizeMB <= 0 {
out.Rotation.MaxSizeMB = 100
}
if out.Rotation.MaxBackups < 0 {
out.Rotation.MaxBackups = 10
}
if out.Rotation.MaxAgeDays < 0 {
out.Rotation.MaxAgeDays = 7
}
if out.Sampling.Enabled {
if out.Sampling.Initial <= 0 {
out.Sampling.Initial = 100
}
if out.Sampling.Thereafter <= 0 {
out.Sampling.Thereafter = 100
}
}
return out
}
func resolveLogFilePath(explicit string) string {
explicit = strings.TrimSpace(explicit)
if explicit != "" {
return explicit
}
dataDir := strings.TrimSpace(os.Getenv("DATA_DIR"))
if dataDir != "" {
return filepath.Join(dataDir, "logs", defaultLogFilename)
}
return DefaultContainerLogPath
}
func bootstrapOptions() InitOptions {
return InitOptions{
Level: "info",
Format: "console",
ServiceName: "sub2api",
Environment: "bootstrap",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 100,
MaxBackups: 10,
MaxAgeDays: 7,
Compress: true,
LocalTime: true,
},
Sampling: SamplingOptions{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
}
}
func parseLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "debug":
return LevelDebug, true
case "info":
return LevelInfo, true
case "warn":
return LevelWarn, true
case "error":
return LevelError, true
default:
return LevelInfo, false
}
}
func parseStacktraceLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "none":
return LevelFatal + 1, true
case "error":
return LevelError, true
case "fatal":
return LevelFatal, true
default:
return LevelError, false
}
}
func samplingTick() time.Duration {
return time.Second
}
package logger
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestResolveLogFilePath_Default(t *testing.T) {
t.Setenv("DATA_DIR", "")
got := resolveLogFilePath("")
if got != DefaultContainerLogPath {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath)
}
}
func TestResolveLogFilePath_WithDataDir(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/sub2api-data")
got := resolveLogFilePath("")
want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log")
if got != want {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, want)
}
}
func TestResolveLogFilePath_ExplicitPath(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/ignore")
got := resolveLogFilePath("/var/log/custom.log")
if got != "/var/log/custom.log" {
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got)
}
}
func TestNormalizedOptions_InvalidFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := InitOptions{
Level: "TRACE",
Format: "TEXT",
ServiceName: "",
Environment: "",
StacktraceLevel: "panic",
Output: OutputOptions{
ToStdout: false,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 0,
MaxBackups: -1,
MaxAgeDays: -1,
},
Sampling: SamplingOptions{
Enabled: true,
Initial: 0,
Thereafter: 0,
},
}
out := opts.normalized()
if out.Level != "trace" {
// normalized 仅做 trim/lower,不做校验;校验在 config 层。
t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level)
}
if !out.Output.ToStdout {
t.Fatalf("normalized output should fallback to stdout")
}
if out.Output.FilePath != DefaultContainerLogPath {
t.Fatalf("normalized file path = %q", out.Output.FilePath)
}
if out.Rotation.MaxSizeMB != 100 {
t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB)
}
if out.Rotation.MaxBackups != 10 {
t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups)
}
if out.Rotation.MaxAgeDays != 7 {
t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays)
}
if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 {
t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling)
}
}
func TestBuildFileCore_InvalidPathFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := bootstrapOptions()
opts.Output.ToFile = true
opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log")
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
MessageKey: "msg",
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeLevel: zapcore.CapitalLevelEncoder,
}
encoder := zapcore.NewJSONEncoder(encoderCfg)
_, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts)
if err == nil {
t.Fatalf("buildFileCore() expected error for invalid path")
}
}
package logger
import (
"context"
"log/slog"
"strings"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type slogZapHandler struct {
logger *zap.Logger
attrs []slog.Attr
groups []string
}
func newSlogZapHandler(logger *zap.Logger) slog.Handler {
if logger == nil {
logger = zap.NewNop()
}
return &slogZapHandler{
logger: logger,
attrs: make([]slog.Attr, 0, 8),
groups: make([]string, 0, 4),
}
}
func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool {
switch {
case level >= slog.LevelError:
return h.logger.Core().Enabled(LevelError)
case level >= slog.LevelWarn:
return h.logger.Core().Enabled(LevelWarn)
case level <= slog.LevelDebug:
return h.logger.Core().Enabled(LevelDebug)
default:
return h.logger.Core().Enabled(LevelInfo)
}
}
func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3)
fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...)
record.Attrs(func(attr slog.Attr) bool {
fields = append(fields, slogAttrToZapField(h.groups, attr))
return true
})
entry := h.logger.With(fields...)
switch {
case record.Level >= slog.LevelError:
entry.Error(record.Message)
case record.Level >= slog.LevelWarn:
entry.Warn(record.Message)
case record.Level <= slog.LevelDebug:
entry.Debug(record.Message)
default:
entry.Info(record.Message)
}
return nil
}
func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
next := *h
next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...)
return &next
}
func (h *slogZapHandler) WithGroup(name string) slog.Handler {
name = strings.TrimSpace(name)
if name == "" {
return h
}
next := *h
next.groups = append(append([]string{}, h.groups...), name)
return &next
}
func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field {
fields := make([]zap.Field, 0, len(attrs))
for _, attr := range attrs {
fields = append(fields, slogAttrToZapField(groups, attr))
}
return fields
}
func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field {
if len(groups) > 0 {
attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".")
}
value := attr.Value.Resolve()
switch value.Kind() {
case slog.KindBool:
return zap.Bool(attr.Key, value.Bool())
case slog.KindInt64:
return zap.Int64(attr.Key, value.Int64())
case slog.KindUint64:
return zap.Uint64(attr.Key, value.Uint64())
case slog.KindFloat64:
return zap.Float64(attr.Key, value.Float64())
case slog.KindDuration:
return zap.Duration(attr.Key, value.Duration())
case slog.KindTime:
return zap.Time(attr.Key, value.Time())
case slog.KindString:
return zap.String(attr.Key, value.String())
case slog.KindGroup:
groupFields := make([]zap.Field, 0, len(value.Group()))
for _, nested := range value.Group() {
groupFields = append(groupFields, slogAttrToZapField(nil, nested))
}
return zap.Object(attr.Key, zapObjectFields(groupFields))
case slog.KindAny:
if t, ok := value.Any().(time.Time); ok {
return zap.Time(attr.Key, t)
}
return zap.Any(attr.Key, value.Any())
default:
return zap.String(attr.Key, value.String())
}
}
type zapObjectFields []zap.Field
func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
for _, field := range z {
field.AddTo(enc)
}
return nil
}
package logger
import (
"context"
"log/slog"
"testing"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type captureState struct {
writes []capturedWrite
}
type capturedWrite struct {
fields []zapcore.Field
}
type captureCore struct {
state *captureState
withFields []zapcore.Field
}
func newCaptureCore() *captureCore {
return &captureCore{state: &captureState{}}
}
func (c *captureCore) Enabled(zapcore.Level) bool {
return true
}
func (c *captureCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
nextFields = append(nextFields, c.withFields...)
nextFields = append(nextFields, fields...)
return &captureCore{
state: c.state,
withFields: nextFields,
}
}
func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
return ce.AddCore(entry, c)
}
func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
allFields = append(allFields, c.withFields...)
allFields = append(allFields, fields...)
c.state.writes = append(c.state.writes, capturedWrite{
fields: allFields,
})
return nil
}
func (c *captureCore) Sync() error {
return nil
}
func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) {
core := newCaptureCore()
handler := newSlogZapHandler(zap.New(core))
record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0)
record.AddAttrs(slog.String("component", "http.access"))
if err := handler.Handle(context.Background(), record); err != nil {
t.Fatalf("handle slog record: %v", err)
}
if len(core.state.writes) != 1 {
t.Fatalf("write calls = %d, want 1", len(core.state.writes))
}
var hasComponent bool
for _, field := range core.state.writes[0].fields {
if field.Key == "time" {
t.Fatalf("unexpected duplicate time field in slog adapter output")
}
if field.Key == "component" {
hasComponent = true
}
}
if !hasComponent {
t.Fatalf("component field should be preserved")
}
}
package logger
import (
"io"
"log"
"os"
"strings"
"testing"
)
func TestInferStdLogLevel(t *testing.T) {
cases := []struct {
msg string
want Level
}{
{msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError},
{msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug},
}
for _, tc := range cases {
got := inferStdLogLevel(tc.msg)
if got != tc.want {
t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want)
}
}
}
func TestNormalizeStdLogMessage(t *testing.T) {
raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n"
got := normalizeStdLogMessage(raw)
want := "[TokenRefresh] cycle complete total=1 failed=0"
if got != want {
t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want)
}
}
func TestStdLogBridgeRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
log.Printf("service started")
log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "service started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "Forward request failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_stdlog\":true") {
t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText)
}
}
func TestLegacyPrintfRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "request started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "forward failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_printf\":true") {
t.Fatalf("stderr missing legacy_printf marker: %s", stderrText)
}
if !strings.Contains(stderrText, "\"component\":\"service.test\"") {
t.Fatalf("stderr missing component field: %s", stderrText)
}
}
...@@ -50,6 +50,7 @@ type OAuthSession struct { ...@@ -50,6 +50,7 @@ type OAuthSession struct {
type SessionStore struct { type SessionStore struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]*OAuthSession sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{} stopCh chan struct{}
} }
...@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore { ...@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
// Stop stops the cleanup goroutine // Stop stops the cleanup goroutine
func (s *SessionStore) Stop() { func (s *SessionStore) Stop() {
close(s.stopCh) s.stopOnce.Do(func() {
close(s.stopCh)
})
} }
// Set stores a session // Set stores a session
......
package oauth
import (
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
...@@ -15,8 +15,8 @@ type Model struct { ...@@ -15,8 +15,8 @@ type Model struct {
// DefaultModels OpenAI models list // DefaultModels OpenAI models list
var DefaultModels = []Model{ var DefaultModels = []Model{
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"}, {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
......
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
const ( const (
// OAuth Client ID for OpenAI (Codex CLI official) // OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints // OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize" AuthorizeURL = "https://auth.openai.com/oauth/authorize"
...@@ -47,6 +49,7 @@ type OAuthSession struct { ...@@ -47,6 +49,7 @@ type OAuthSession struct {
type SessionStore struct { type SessionStore struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]*OAuthSession sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{} stopCh chan struct{}
} }
...@@ -92,7 +95,9 @@ func (s *SessionStore) Delete(sessionID string) { ...@@ -92,7 +95,9 @@ func (s *SessionStore) Delete(sessionID string) {
// Stop stops the cleanup goroutine // Stop stops the cleanup goroutine
func (s *SessionStore) Stop() { func (s *SessionStore) Stop() {
close(s.stopCh) s.stopOnce.Do(func() {
close(s.stopCh)
})
} }
// cleanup removes expired sessions periodically // cleanup removes expired sessions periodically
......
package openai
import (
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
package openai package openai
import "strings"
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns // CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2" // Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{ var CodexCLIUserAgentPrefixes = []string{
...@@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{ ...@@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{
"codex_cli_rs/", "codex_cli_rs/",
} }
// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。
// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。
var CodexOfficialClientUserAgentPrefixes = []string{
"codex_cli_rs/",
"codex_vscode/",
"codex_app/",
"codex_chatgpt_desktop/",
"codex_atlas/",
"codex_exec/",
"codex_sdk_ts/",
"codex ",
}
// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。
// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。
// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。
var CodexOfficialClientOriginatorPrefixes = []string{
"codex_",
"codex ",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool { func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes { ua := normalizeCodexClientHeader(userAgent)
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix { if ua == "" {
return false
}
return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes)
}
// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。
// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。
func IsCodexOfficialClientRequest(userAgent string) bool {
ua := normalizeCodexClientHeader(userAgent)
if ua == "" {
return false
}
return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes)
}
// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。
func IsCodexOfficialClientOriginator(originator string) bool {
v := normalizeCodexClientHeader(originator)
if v == "" {
return false
}
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
}
func normalizeCodexClientHeader(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}
func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool {
for _, prefix := range prefixes {
normalizedPrefix := normalizeCodexClientHeader(prefix)
if normalizedPrefix == "" {
continue
}
// 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。
if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) {
return true return true
} }
} }
......
package openai
import "testing"
func TestIsCodexCLIRequest(t *testing.T) {
tests := []struct {
name string
ua string
want bool
}{
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true},
{name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true},
{name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true},
{name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true},
{name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true},
{name: "非 codex", ua: "curl/8.0.1", want: false},
{name: "空字符串", ua: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexCLIRequest(tt.ua)
if got != tt.want {
t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want)
}
})
}
}
func TestIsCodexOfficialClientRequest(t *testing.T) {
tests := []struct {
name string
ua string
want bool
}{
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true},
{name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true},
{name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true},
{name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", want: true},
{name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true},
{name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true},
{name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true},
{name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true},
{name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true},
{name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true},
{name: "非 codex", ua: "curl/8.0.1", want: false},
{name: "空字符串", ua: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientRequest(tt.ua)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want)
}
})
}
}
func TestIsCodexOfficialClientOriginator(t *testing.T) {
tests := []struct {
name string
originator string
want bool
}{
{name: "codex_cli_rs", originator: "codex_cli_rs", want: true},
{name: "codex_vscode", originator: "codex_vscode", want: true},
{name: "codex_app", originator: "codex_app", want: true},
{name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true},
{name: "codex_atlas", originator: "codex_atlas", want: true},
{name: "codex_exec", originator: "codex_exec", want: true},
{name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true},
{name: "Codex 前缀", originator: "Codex Desktop", want: true},
{name: "空白包裹", originator: " codex_vscode ", want: true},
{name: "非 codex", originator: "my_client", want: false},
{name: "空字符串", originator: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientOriginator(tt.originator)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want)
}
})
}
}
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool { ...@@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool {
// Log internal errors with full details for debugging // Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil { if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error()) log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
} }
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
......
...@@ -14,6 +14,44 @@ import ( ...@@ -14,6 +14,44 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// ---------- 辅助函数 ----------
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
t.Helper()
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
return got
}
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型
var raw struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
var pd PaginatedData
require.NoError(t, json.Unmarshal(raw.Data, &pd))
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
}
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
return w, c
}
// ---------- 现有测试 ----------
func TestErrorWithDetails(t *testing.T) { func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
...@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) { ...@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
}) })
} }
} }
// ---------- 新增测试 ----------
func TestSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
wantBody Response
}{
{
name: "返回字符串数据",
data: "hello",
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
},
{
name: "返回nil数据",
data: nil,
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
{
name: "返回map数据",
data: map[string]string{"key": "value"},
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Success(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
if tt.data == nil {
require.Nil(t, got.Data)
} else {
require.NotNil(t, got.Data)
}
})
}
}
func TestCreated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
}{
{
name: "创建成功_返回数据",
data: map[string]int{"id": 42},
wantCode: http.StatusCreated,
},
{
name: "创建成功_nil数据",
data: nil,
wantCode: http.StatusCreated,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Created(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
})
}
}
func TestError(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
}{
{
name: "400错误",
statusCode: http.StatusBadRequest,
message: "bad request",
},
{
name: "500错误",
statusCode: http.StatusInternalServerError,
message: "internal error",
},
{
name: "自定义状态码",
statusCode: 418,
message: "I'm a teapot",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Error(c, tt.statusCode, tt.message)
require.Equal(t, tt.statusCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, tt.statusCode, got.Code)
require.Equal(t, tt.message, got.Message)
require.Empty(t, got.Reason)
require.Nil(t, got.Metadata)
require.Nil(t, got.Data)
})
}
}
func TestBadRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
BadRequest(c, "参数无效")
require.Equal(t, http.StatusBadRequest, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusBadRequest, got.Code)
require.Equal(t, "参数无效", got.Message)
}
func TestUnauthorized(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Unauthorized(c, "未登录")
require.Equal(t, http.StatusUnauthorized, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusUnauthorized, got.Code)
require.Equal(t, "未登录", got.Message)
}
func TestForbidden(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Forbidden(c, "无权限")
require.Equal(t, http.StatusForbidden, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusForbidden, got.Code)
require.Equal(t, "无权限", got.Message)
}
func TestNotFound(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
NotFound(c, "资源不存在")
require.Equal(t, http.StatusNotFound, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusNotFound, got.Code)
require.Equal(t, "资源不存在", got.Message)
}
func TestInternalError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
InternalError(c, "服务器内部错误")
require.Equal(t, http.StatusInternalServerError, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusInternalServerError, got.Code)
require.Equal(t, "服务器内部错误", got.Message)
}
func TestPaginated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
total int64
page int
pageSize int
wantPages int
wantTotal int64
wantPage int
wantPageSize int
}{
{
name: "标准分页_多页",
items: []string{"a", "b"},
total: 25,
page: 1,
pageSize: 10,
wantPages: 3,
wantTotal: 25,
wantPage: 1,
wantPageSize: 10,
},
{
name: "总数刚好整除",
items: []string{"a"},
total: 20,
page: 2,
pageSize: 10,
wantPages: 2,
wantTotal: 20,
wantPage: 2,
wantPageSize: 10,
},
{
name: "总数为0_pages至少为1",
items: []string{},
total: 0,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 0,
wantPage: 1,
wantPageSize: 10,
},
{
name: "单页数据",
items: []int{1, 2, 3},
total: 3,
page: 1,
pageSize: 20,
wantPages: 1,
wantTotal: 3,
wantPage: 1,
wantPageSize: 20,
},
{
name: "总数为1",
items: []string{"only"},
total: 1,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 1,
wantPage: 1,
wantPageSize: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestPaginatedWithResult(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
pagination *PaginationResult
wantTotal int64
wantPage int
wantPageSize int
wantPages int
}{
{
name: "正常分页结果",
items: []string{"a", "b"},
pagination: &PaginationResult{
Total: 50,
Page: 3,
PageSize: 10,
Pages: 5,
},
wantTotal: 50,
wantPage: 3,
wantPageSize: 10,
wantPages: 5,
},
{
name: "pagination为nil_使用默认值",
items: []string{},
pagination: nil,
wantTotal: 0,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
{
name: "单页结果",
items: []int{1},
pagination: &PaginationResult{
Total: 1,
Page: 1,
PageSize: 20,
Pages: 1,
},
wantTotal: 1,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
PaginatedWithResult(c, tt.items, tt.pagination)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestParsePagination(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
query string
wantPage int
wantPageSize int
}{
{
name: "无参数_使用默认值",
query: "",
wantPage: 1,
wantPageSize: 20,
},
{
name: "仅指定page",
query: "page=3",
wantPage: 3,
wantPageSize: 20,
},
{
name: "仅指定page_size",
query: "page_size=50",
wantPage: 1,
wantPageSize: 50,
},
{
name: "同时指定page和page_size",
query: "page=2&page_size=30",
wantPage: 2,
wantPageSize: 30,
},
{
name: "使用limit代替page_size",
query: "limit=15",
wantPage: 1,
wantPageSize: 15,
},
{
name: "page_size优先于limit",
query: "page_size=25&limit=50",
wantPage: 1,
wantPageSize: 25,
},
{
name: "page为0_使用默认值",
query: "page=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size超过1000_使用默认值",
query: "page_size=1001",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size恰好1000_有效",
query: "page_size=1000",
wantPage: 1,
wantPageSize: 1000,
},
{
name: "page为非数字_使用默认值",
query: "page=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为非数字_使用默认值",
query: "page_size=xyz",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为非数字_使用默认值",
query: "limit=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为0_使用默认值",
query: "page_size=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为0_使用默认值",
query: "limit=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "大页码",
query: "page=999&page_size=100",
wantPage: 999,
wantPageSize: 100,
},
{
name: "page_size为1_最小有效值",
query: "page_size=1",
wantPage: 1,
wantPageSize: 1,
},
{
name: "混合数字和字母的page",
query: "page=12a",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit超过1000_使用默认值",
query: "limit=2000",
wantPage: 1,
wantPageSize: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, c := newContextWithQuery(tt.query)
page, pageSize := ParsePagination(c)
require.Equal(t, tt.wantPage, page, "page 不符合预期")
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
})
}
}
func Test_parseInt(t *testing.T) {
tests := []struct {
name string
input string
wantVal int
wantErr bool
}{
{
name: "正常数字",
input: "123",
wantVal: 123,
wantErr: false,
},
{
name: "零",
input: "0",
wantVal: 0,
wantErr: false,
},
{
name: "单个数字",
input: "5",
wantVal: 5,
wantErr: false,
},
{
name: "大数字",
input: "99999",
wantVal: 99999,
wantErr: false,
},
{
name: "包含字母_返回0",
input: "abc",
wantVal: 0,
wantErr: false,
},
{
name: "数字开头接字母_返回0",
input: "12a",
wantVal: 0,
wantErr: false,
},
{
name: "包含负号_返回0",
input: "-1",
wantVal: 0,
wantErr: false,
},
{
name: "包含小数点_返回0",
input: "1.5",
wantVal: 0,
wantErr: false,
},
{
name: "包含空格_返回0",
input: "1 2",
wantVal: 0,
wantErr: false,
},
{
name: "空字符串",
input: "",
wantVal: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := parseInt(tt.input)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantVal, val)
})
}
}
...@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st ...@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
return nil, fmt.Errorf("apply TLS preset: %w", err) return nil, fmt.Errorf("apply TLS preset: %w", err)
} }
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
_ = conn.Close() _ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err) return nil, fmt.Errorf("TLS handshake failed: %w", err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment