Commit 64236361 authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'test' into dev

parents 2d6066f9 b6aaee01
package logger
import "github.com/Wei-Shaw/sub2api/internal/config"
func OptionsFromConfig(cfg config.LogConfig) InitOptions {
return InitOptions{
Level: cfg.Level,
Format: cfg.Format,
ServiceName: cfg.ServiceName,
Environment: cfg.Environment,
Caller: cfg.Caller,
StacktraceLevel: cfg.StacktraceLevel,
Output: OutputOptions{
ToStdout: cfg.Output.ToStdout,
ToFile: cfg.Output.ToFile,
FilePath: cfg.Output.FilePath,
},
Rotation: RotationOptions{
MaxSizeMB: cfg.Rotation.MaxSizeMB,
MaxBackups: cfg.Rotation.MaxBackups,
MaxAgeDays: cfg.Rotation.MaxAgeDays,
Compress: cfg.Rotation.Compress,
LocalTime: cfg.Rotation.LocalTime,
},
Sampling: SamplingOptions{
Enabled: cfg.Sampling.Enabled,
Initial: cfg.Sampling.Initial,
Thereafter: cfg.Sampling.Thereafter,
},
}
}
package logger
import (
"context"
"fmt"
"io"
"log"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
type Level = zapcore.Level
const (
LevelDebug = zapcore.DebugLevel
LevelInfo = zapcore.InfoLevel
LevelWarn = zapcore.WarnLevel
LevelError = zapcore.ErrorLevel
LevelFatal = zapcore.FatalLevel
)
type Sink interface {
WriteLogEvent(event *LogEvent)
}
type LogEvent struct {
Time time.Time
Level string
Component string
Message string
LoggerName string
Fields map[string]any
}
var (
mu sync.RWMutex
global *zap.Logger
sugar *zap.SugaredLogger
atomicLevel zap.AtomicLevel
initOptions InitOptions
currentSink Sink
stdLogUndo func()
bootstrapOnce sync.Once
)
func InitBootstrap() {
bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err)
}
})
}
func Init(options InitOptions) error {
mu.Lock()
defer mu.Unlock()
return initLocked(options)
}
func initLocked(options InitOptions) error {
normalized := options.normalized()
zl, al, err := buildLogger(normalized)
if err != nil {
return err
}
prev := global
global = zl
sugar = zl.Sugar()
atomicLevel = al
initOptions = normalized
bridgeSlogLocked()
bridgeStdLogLocked()
if prev != nil {
_ = prev.Sync()
}
return nil
}
func Reconfigure(mutator func(*InitOptions) error) error {
mu.Lock()
defer mu.Unlock()
next := initOptions
if mutator != nil {
if err := mutator(&next); err != nil {
return err
}
}
return initLocked(next)
}
func SetLevel(level string) error {
lv, ok := parseLevel(level)
if !ok {
return fmt.Errorf("invalid log level: %s", level)
}
mu.Lock()
defer mu.Unlock()
atomicLevel.SetLevel(lv)
initOptions.Level = strings.ToLower(strings.TrimSpace(level))
return nil
}
func CurrentLevel() string {
mu.RLock()
defer mu.RUnlock()
if global == nil {
return "info"
}
return atomicLevel.Level().String()
}
func SetSink(sink Sink) {
mu.Lock()
defer mu.Unlock()
currentSink = sink
}
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) {
mu.RLock()
sink := currentSink
mu.RUnlock()
if sink == nil {
return
}
level = strings.ToLower(strings.TrimSpace(level))
if level == "" {
level = "info"
}
component = strings.TrimSpace(component)
message = strings.TrimSpace(message)
if message == "" {
return
}
eventFields := make(map[string]any, len(fields)+1)
for k, v := range fields {
eventFields[k] = v
}
if component != "" {
if _, ok := eventFields["component"]; !ok {
eventFields["component"] = component
}
}
sink.WriteLogEvent(&LogEvent{
Time: time.Now(),
Level: level,
Component: component,
Message: message,
LoggerName: component,
Fields: eventFields,
})
}
func L() *zap.Logger {
mu.RLock()
defer mu.RUnlock()
if global != nil {
return global
}
return zap.NewNop()
}
func S() *zap.SugaredLogger {
mu.RLock()
defer mu.RUnlock()
if sugar != nil {
return sugar
}
return zap.NewNop().Sugar()
}
func With(fields ...zap.Field) *zap.Logger {
return L().With(fields...)
}
func Sync() {
mu.RLock()
l := global
mu.RUnlock()
if l != nil {
_ = l.Sync()
}
}
func bridgeStdLogLocked() {
if stdLogUndo != nil {
stdLogUndo()
stdLogUndo = nil
}
prevFlags := log.Flags()
prevPrefix := log.Prefix()
prevWriter := log.Writer()
log.SetFlags(0)
log.SetPrefix("")
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
stdLogUndo = func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}
}
func bridgeSlogLocked() {
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
}
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
level, _ := parseLevel(options.Level)
atomic := zap.NewAtomicLevelAt(level)
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.CapitalLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.MillisDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
var enc zapcore.Encoder
if options.Format == "console" {
enc = zapcore.NewConsoleEncoder(encoderCfg)
} else {
enc = zapcore.NewJSONEncoder(encoderCfg)
}
sinkCore := newSinkCore()
cores := make([]zapcore.Core, 0, 3)
if options.Output.ToStdout {
infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl < zapcore.WarnLevel
})
errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel
})
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority))
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority))
}
if options.Output.ToFile {
fileCore, filePath, fileErr := buildFileCore(enc, atomic, options)
if fileErr != nil {
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n",
time.Now().Format(time.RFC3339Nano),
filePath,
fileErr,
)
} else {
cores = append(cores, fileCore)
}
}
if len(cores) == 0 {
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic))
}
core := zapcore.NewTee(cores...)
if options.Sampling.Enabled {
core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter)
}
core = sinkCore.Wrap(core)
stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel)
zapOpts := make([]zap.Option, 0, 5)
if options.Caller {
zapOpts = append(zapOpts, zap.AddCaller())
}
if stacktraceLevel <= zapcore.FatalLevel {
zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel))
}
logger := zap.New(core, zapOpts...).With(
zap.String("service", options.ServiceName),
zap.String("env", options.Environment),
)
return logger, atomic, nil
}
func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) {
filePath := options.Output.FilePath
if strings.TrimSpace(filePath) == "" {
filePath = resolveLogFilePath("")
}
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, filePath, err
}
lj := &lumberjack.Logger{
Filename: filePath,
MaxSize: options.Rotation.MaxSizeMB,
MaxBackups: options.Rotation.MaxBackups,
MaxAge: options.Rotation.MaxAgeDays,
Compress: options.Rotation.Compress,
LocalTime: options.Rotation.LocalTime,
}
return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil
}
type sinkCore struct {
core zapcore.Core
fields []zapcore.Field
}
func newSinkCore() *sinkCore {
return &sinkCore{}
}
func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core {
cp := *s
cp.core = core
return &cp
}
func (s *sinkCore) Enabled(level zapcore.Level) bool {
return s.core.Enabled(level)
}
func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := append([]zapcore.Field{}, s.fields...)
nextFields = append(nextFields, fields...)
return &sinkCore{
core: s.core.With(fields),
fields: nextFields,
}
}
func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
if s.Enabled(entry.Level) {
return ce.AddCore(entry, s)
}
return ce
}
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
if err := s.core.Write(entry, fields); err != nil {
return err
}
mu.RLock()
sink := currentSink
mu.RUnlock()
if sink == nil {
return nil
}
enc := zapcore.NewMapObjectEncoder()
for _, f := range s.fields {
f.AddTo(enc)
}
for _, f := range fields {
f.AddTo(enc)
}
event := &LogEvent{
Time: entry.Time,
Level: strings.ToLower(entry.Level.String()),
Component: entry.LoggerName,
Message: entry.Message,
LoggerName: entry.LoggerName,
Fields: enc.Fields,
}
sink.WriteLogEvent(event)
return nil
}
func (s *sinkCore) Sync() error {
return s.core.Sync()
}
type stdLogBridge struct {
logger *zap.Logger
}
func newStdLogBridge(l *zap.Logger) io.Writer {
if l == nil {
l = zap.NewNop()
}
return &stdLogBridge{logger: l}
}
func (b *stdLogBridge) Write(p []byte) (int, error) {
msg := normalizeStdLogMessage(string(p))
if msg == "" {
return len(p), nil
}
level := inferStdLogLevel(msg)
entry := b.logger.WithOptions(zap.AddCallerSkip(4))
switch level {
case LevelDebug:
entry.Debug(msg, zap.Bool("legacy_stdlog", true))
case LevelWarn:
entry.Warn(msg, zap.Bool("legacy_stdlog", true))
case LevelError, LevelFatal:
entry.Error(msg, zap.Bool("legacy_stdlog", true))
default:
entry.Info(msg, zap.Bool("legacy_stdlog", true))
}
return len(p), nil
}
func normalizeStdLogMessage(raw string) string {
msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " "))
if msg == "" {
return ""
}
return strings.Join(strings.Fields(msg), " ")
}
func inferStdLogLevel(msg string) Level {
lower := strings.ToLower(strings.TrimSpace(msg))
if lower == "" {
return LevelInfo
}
if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") {
return LevelDebug
}
if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") {
return LevelWarn
}
if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") {
return LevelError
}
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError
}
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn
}
return LevelInfo
}
// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。
func LegacyPrintf(component, format string, args ...any) {
msg := normalizeStdLogMessage(fmt.Sprintf(format, args...))
if msg == "" {
return
}
mu.RLock()
initialized := global != nil
mu.RUnlock()
if !initialized {
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
log.Print(msg)
return
}
l := L()
if component != "" {
l = l.With(zap.String("component", component))
}
l = l.WithOptions(zap.AddCallerSkip(1))
switch inferStdLogLevel(msg) {
case LevelDebug:
l.Debug(msg, zap.Bool("legacy_printf", true))
case LevelWarn:
l.Warn(msg, zap.Bool("legacy_printf", true))
case LevelError, LevelFatal:
l.Error(msg, zap.Bool("legacy_printf", true))
default:
l.Info(msg, zap.Bool("legacy_printf", true))
}
}
type contextKey string
const loggerContextKey contextKey = "ctx_logger"
func IntoContext(ctx context.Context, l *zap.Logger) context.Context {
if ctx == nil {
ctx = context.Background()
}
if l == nil {
l = L()
}
return context.WithValue(ctx, loggerContextKey, l)
}
func FromContext(ctx context.Context) *zap.Logger {
if ctx == nil {
return L()
}
if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil {
return l
}
return L()
}
package logger
import (
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestInit_DualOutput(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stderrR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: logPath,
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 2,
MaxAgeDays: 1,
},
Sampling: SamplingOptions{Enabled: false},
})
if err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("dual-output-info")
L().Warn("dual-output-warn")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "dual-output-info") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "dual-output-warn") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
fileBytes, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("read log file: %v", err)
}
fileText := string(fileBytes)
if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") {
t.Fatalf("file missing logs: %s", fileText)
}
}
func TestInit_FileOutputFailureDowngrade(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
_, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "info",
Format: "json",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"),
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 1,
MaxAgeDays: 1,
},
})
if err != nil {
t.Fatalf("Init() should downgrade instead of failing, got: %v", err)
}
_ = stderrW.Close()
stderrBytes, _ := io.ReadAll(stderrR)
if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") {
t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes))
}
}
func TestInit_CallerShouldPointToCallsite(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
_, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Caller: true,
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("caller-check")
Sync()
_ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR)
var line string
for _, item := range strings.Split(string(logBytes), "\n") {
if strings.Contains(item, "caller-check") {
line = item
break
}
}
if line == "" {
t.Fatalf("log output missing caller-check: %s", string(logBytes))
}
var payload map[string]any
if err := json.Unmarshal([]byte(line), &payload); err != nil {
t.Fatalf("parse log json failed: %v, line=%s", err, line)
}
caller, _ := payload["caller"].(string)
if !strings.Contains(caller, "logger_test.go:") {
t.Fatalf("caller should point to this test file, got: %s", caller)
}
}
package logger
import (
"os"
"path/filepath"
"strings"
"time"
)
const (
// DefaultContainerLogPath 为容器内默认日志文件路径。
DefaultContainerLogPath = "/app/data/logs/sub2api.log"
defaultLogFilename = "sub2api.log"
)
type InitOptions struct {
Level string
Format string
ServiceName string
Environment string
Caller bool
StacktraceLevel string
Output OutputOptions
Rotation RotationOptions
Sampling SamplingOptions
}
type OutputOptions struct {
ToStdout bool
ToFile bool
FilePath string
}
type RotationOptions struct {
MaxSizeMB int
MaxBackups int
MaxAgeDays int
Compress bool
LocalTime bool
}
type SamplingOptions struct {
Enabled bool
Initial int
Thereafter int
}
func (o InitOptions) normalized() InitOptions {
out := o
out.Level = strings.ToLower(strings.TrimSpace(out.Level))
if out.Level == "" {
out.Level = "info"
}
out.Format = strings.ToLower(strings.TrimSpace(out.Format))
if out.Format == "" {
out.Format = "console"
}
out.ServiceName = strings.TrimSpace(out.ServiceName)
if out.ServiceName == "" {
out.ServiceName = "sub2api"
}
out.Environment = strings.TrimSpace(out.Environment)
if out.Environment == "" {
out.Environment = "production"
}
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel))
if out.StacktraceLevel == "" {
out.StacktraceLevel = "error"
}
if !out.Output.ToStdout && !out.Output.ToFile {
out.Output.ToStdout = true
}
out.Output.FilePath = resolveLogFilePath(out.Output.FilePath)
if out.Rotation.MaxSizeMB <= 0 {
out.Rotation.MaxSizeMB = 100
}
if out.Rotation.MaxBackups < 0 {
out.Rotation.MaxBackups = 10
}
if out.Rotation.MaxAgeDays < 0 {
out.Rotation.MaxAgeDays = 7
}
if out.Sampling.Enabled {
if out.Sampling.Initial <= 0 {
out.Sampling.Initial = 100
}
if out.Sampling.Thereafter <= 0 {
out.Sampling.Thereafter = 100
}
}
return out
}
func resolveLogFilePath(explicit string) string {
explicit = strings.TrimSpace(explicit)
if explicit != "" {
return explicit
}
dataDir := strings.TrimSpace(os.Getenv("DATA_DIR"))
if dataDir != "" {
return filepath.Join(dataDir, "logs", defaultLogFilename)
}
return DefaultContainerLogPath
}
func bootstrapOptions() InitOptions {
return InitOptions{
Level: "info",
Format: "console",
ServiceName: "sub2api",
Environment: "bootstrap",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 100,
MaxBackups: 10,
MaxAgeDays: 7,
Compress: true,
LocalTime: true,
},
Sampling: SamplingOptions{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
}
}
func parseLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "debug":
return LevelDebug, true
case "info":
return LevelInfo, true
case "warn":
return LevelWarn, true
case "error":
return LevelError, true
default:
return LevelInfo, false
}
}
func parseStacktraceLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "none":
return LevelFatal + 1, true
case "error":
return LevelError, true
case "fatal":
return LevelFatal, true
default:
return LevelError, false
}
}
func samplingTick() time.Duration {
return time.Second
}
package logger
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestResolveLogFilePath_Default(t *testing.T) {
t.Setenv("DATA_DIR", "")
got := resolveLogFilePath("")
if got != DefaultContainerLogPath {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath)
}
}
func TestResolveLogFilePath_WithDataDir(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/sub2api-data")
got := resolveLogFilePath("")
want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log")
if got != want {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, want)
}
}
func TestResolveLogFilePath_ExplicitPath(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/ignore")
got := resolveLogFilePath("/var/log/custom.log")
if got != "/var/log/custom.log" {
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got)
}
}
func TestNormalizedOptions_InvalidFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := InitOptions{
Level: "TRACE",
Format: "TEXT",
ServiceName: "",
Environment: "",
StacktraceLevel: "panic",
Output: OutputOptions{
ToStdout: false,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 0,
MaxBackups: -1,
MaxAgeDays: -1,
},
Sampling: SamplingOptions{
Enabled: true,
Initial: 0,
Thereafter: 0,
},
}
out := opts.normalized()
if out.Level != "trace" {
// normalized 仅做 trim/lower,不做校验;校验在 config 层。
t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level)
}
if !out.Output.ToStdout {
t.Fatalf("normalized output should fallback to stdout")
}
if out.Output.FilePath != DefaultContainerLogPath {
t.Fatalf("normalized file path = %q", out.Output.FilePath)
}
if out.Rotation.MaxSizeMB != 100 {
t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB)
}
if out.Rotation.MaxBackups != 10 {
t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups)
}
if out.Rotation.MaxAgeDays != 7 {
t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays)
}
if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 {
t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling)
}
}
func TestBuildFileCore_InvalidPathFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := bootstrapOptions()
opts.Output.ToFile = true
opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log")
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
MessageKey: "msg",
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeLevel: zapcore.CapitalLevelEncoder,
}
encoder := zapcore.NewJSONEncoder(encoderCfg)
_, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts)
if err == nil {
t.Fatalf("buildFileCore() expected error for invalid path")
}
}
package logger
import (
"context"
"log/slog"
"strings"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type slogZapHandler struct {
logger *zap.Logger
attrs []slog.Attr
groups []string
}
func newSlogZapHandler(logger *zap.Logger) slog.Handler {
if logger == nil {
logger = zap.NewNop()
}
return &slogZapHandler{
logger: logger,
attrs: make([]slog.Attr, 0, 8),
groups: make([]string, 0, 4),
}
}
func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool {
switch {
case level >= slog.LevelError:
return h.logger.Core().Enabled(LevelError)
case level >= slog.LevelWarn:
return h.logger.Core().Enabled(LevelWarn)
case level <= slog.LevelDebug:
return h.logger.Core().Enabled(LevelDebug)
default:
return h.logger.Core().Enabled(LevelInfo)
}
}
func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3)
fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...)
record.Attrs(func(attr slog.Attr) bool {
fields = append(fields, slogAttrToZapField(h.groups, attr))
return true
})
entry := h.logger.With(fields...)
switch {
case record.Level >= slog.LevelError:
entry.Error(record.Message)
case record.Level >= slog.LevelWarn:
entry.Warn(record.Message)
case record.Level <= slog.LevelDebug:
entry.Debug(record.Message)
default:
entry.Info(record.Message)
}
return nil
}
func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
next := *h
next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...)
return &next
}
func (h *slogZapHandler) WithGroup(name string) slog.Handler {
name = strings.TrimSpace(name)
if name == "" {
return h
}
next := *h
next.groups = append(append([]string{}, h.groups...), name)
return &next
}
func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field {
fields := make([]zap.Field, 0, len(attrs))
for _, attr := range attrs {
fields = append(fields, slogAttrToZapField(groups, attr))
}
return fields
}
func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field {
if len(groups) > 0 {
attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".")
}
value := attr.Value.Resolve()
switch value.Kind() {
case slog.KindBool:
return zap.Bool(attr.Key, value.Bool())
case slog.KindInt64:
return zap.Int64(attr.Key, value.Int64())
case slog.KindUint64:
return zap.Uint64(attr.Key, value.Uint64())
case slog.KindFloat64:
return zap.Float64(attr.Key, value.Float64())
case slog.KindDuration:
return zap.Duration(attr.Key, value.Duration())
case slog.KindTime:
return zap.Time(attr.Key, value.Time())
case slog.KindString:
return zap.String(attr.Key, value.String())
case slog.KindGroup:
groupFields := make([]zap.Field, 0, len(value.Group()))
for _, nested := range value.Group() {
groupFields = append(groupFields, slogAttrToZapField(nil, nested))
}
return zap.Object(attr.Key, zapObjectFields(groupFields))
case slog.KindAny:
if t, ok := value.Any().(time.Time); ok {
return zap.Time(attr.Key, t)
}
return zap.Any(attr.Key, value.Any())
default:
return zap.String(attr.Key, value.String())
}
}
type zapObjectFields []zap.Field
func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
for _, field := range z {
field.AddTo(enc)
}
return nil
}
package logger
import (
"context"
"log/slog"
"testing"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type captureState struct {
writes []capturedWrite
}
type capturedWrite struct {
entry zapcore.Entry
fields []zapcore.Field
}
type captureCore struct {
state *captureState
withFields []zapcore.Field
}
func newCaptureCore() *captureCore {
return &captureCore{state: &captureState{}}
}
func (c *captureCore) Enabled(zapcore.Level) bool {
return true
}
func (c *captureCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
nextFields = append(nextFields, c.withFields...)
nextFields = append(nextFields, fields...)
return &captureCore{
state: c.state,
withFields: nextFields,
}
}
func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
return ce.AddCore(entry, c)
}
func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
allFields = append(allFields, c.withFields...)
allFields = append(allFields, fields...)
c.state.writes = append(c.state.writes, capturedWrite{
entry: entry,
fields: allFields,
})
return nil
}
func (c *captureCore) Sync() error {
return nil
}
func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) {
core := newCaptureCore()
handler := newSlogZapHandler(zap.New(core))
record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0)
record.AddAttrs(slog.String("component", "http.access"))
if err := handler.Handle(context.Background(), record); err != nil {
t.Fatalf("handle slog record: %v", err)
}
if len(core.state.writes) != 1 {
t.Fatalf("write calls = %d, want 1", len(core.state.writes))
}
var hasComponent bool
for _, field := range core.state.writes[0].fields {
if field.Key == "time" {
t.Fatalf("unexpected duplicate time field in slog adapter output")
}
if field.Key == "component" {
hasComponent = true
}
}
if !hasComponent {
t.Fatalf("component field should be preserved")
}
}
package logger
import (
"io"
"log"
"os"
"strings"
"testing"
)
func TestInferStdLogLevel(t *testing.T) {
cases := []struct {
msg string
want Level
}{
{msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError},
{msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug},
}
for _, tc := range cases {
got := inferStdLogLevel(tc.msg)
if got != tc.want {
t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want)
}
}
}
func TestNormalizeStdLogMessage(t *testing.T) {
raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n"
got := normalizeStdLogMessage(raw)
want := "[TokenRefresh] cycle complete total=1 failed=0"
if got != want {
t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want)
}
}
func TestStdLogBridgeRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
log.Printf("service started")
log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "service started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "Forward request failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_stdlog\":true") {
t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText)
}
}
func TestLegacyPrintfRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "request started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "forward failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_printf\":true") {
t.Fatalf("stderr missing legacy_printf marker: %s", stderrText)
}
if !strings.Contains(stderrText, "\"component\":\"service.test\"") {
t.Fatalf("stderr missing component field: %s", stderrText)
}
}
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"log"
"strconv" "strconv"
"time" "time"
...@@ -25,6 +24,7 @@ import ( ...@@ -25,6 +24,7 @@ import (
dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy" dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
...@@ -127,7 +127,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account ...@@ -127,7 +127,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account.CreatedAt = created.CreatedAt account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt account.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
} }
return nil return nil
} }
...@@ -388,7 +388,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -388,7 +388,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
} }
account.UpdatedAt = updated.UpdatedAt account.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
} }
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID) r.syncSchedulerAccountSnapshot(ctx, account.ID)
...@@ -429,7 +429,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { ...@@ -429,7 +429,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
} }
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -525,7 +525,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error ...@@ -525,7 +525,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
}, },
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -560,7 +560,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map ...@@ -560,7 +560,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
} }
payload := map[string]any{"last_used": lastUsedPayload} payload := map[string]any{"last_used": lastUsedPayload}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
} }
return nil return nil
} }
...@@ -575,7 +575,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str ...@@ -575,7 +575,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
return err return err
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
} }
r.syncSchedulerAccountSnapshot(ctx, id) r.syncSchedulerAccountSnapshot(ctx, id)
return nil return nil
...@@ -595,11 +595,11 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac ...@@ -595,11 +595,11 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
} }
account, err := r.GetByID(ctx, accountID) account, err := r.GetByID(ctx, accountID)
if err != nil { if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return return
} }
if err := r.schedulerCache.SetAccount(ctx, account); err != nil { if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
} }
} }
...@@ -623,7 +623,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i ...@@ -623,7 +623,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
} }
payload := buildSchedulerGroupPayload([]int64{groupID}) payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
} }
return nil return nil
} }
...@@ -640,7 +640,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou ...@@ -640,7 +640,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
} }
payload := buildSchedulerGroupPayload([]int64{groupID}) payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
} }
return nil return nil
} }
...@@ -713,7 +713,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro ...@@ -713,7 +713,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
} }
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs)) payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
} }
return nil return nil
} }
...@@ -821,7 +821,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA ...@@ -821,7 +821,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return err return err
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -868,7 +868,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco ...@@ -868,7 +868,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -882,7 +882,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t ...@@ -882,7 +882,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
return err return err
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -901,7 +901,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, ...@@ -901,7 +901,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
return err return err
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
} }
r.syncSchedulerAccountSnapshot(ctx, id) r.syncSchedulerAccountSnapshot(ctx, id)
return nil return nil
...@@ -920,7 +920,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64 ...@@ -920,7 +920,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
return err return err
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -936,7 +936,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error ...@@ -936,7 +936,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return err return err
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -960,7 +960,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id ...@@ -960,7 +960,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -984,7 +984,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) ...@@ -984,7 +984,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -1006,7 +1006,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s ...@@ -1006,7 +1006,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
// 触发调度器缓存更新(仅当窗口时间有变化时) // 触发调度器缓存更新(仅当窗口时间有变化时)
if start != nil || end != nil { if start != nil || end != nil {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
} }
} }
return nil return nil
...@@ -1021,7 +1021,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu ...@@ -1021,7 +1021,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return err return err
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
} }
if !schedulable { if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id) r.syncSchedulerAccountSnapshot(ctx, id)
...@@ -1049,7 +1049,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti ...@@ -1049,7 +1049,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
} }
if rows > 0 { if rows > 0 {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
} }
} }
return rows, nil return rows, nil
...@@ -1085,7 +1085,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m ...@@ -1085,7 +1085,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -1179,7 +1179,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -1179,7 +1179,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if rows > 0 { if rows > 0 {
payload := map[string]any{"account_ids": ids} payload := map[string]any{"account_ids": ids}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
} }
shouldSync := false shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) { if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
......
...@@ -4,12 +4,12 @@ import ( ...@@ -4,12 +4,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/Wei-Shaw/sub2api/internal/util/logredact"
...@@ -41,7 +41,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey ...@@ -41,7 +41,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
} }
targetURL := s.baseURL + "/api/organizations" targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R(). resp, err := client.R().
SetContext(ctx). SetContext(ctx).
...@@ -53,11 +53,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey ...@@ -53,11 +53,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
Get(targetURL) Get(targetURL)
if err != nil { if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -69,21 +69,21 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey ...@@ -69,21 +69,21 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
// 如果只有一个组织,直接使用 // 如果只有一个组织,直接使用
if len(orgs) == 1 { if len(orgs) == 1 {
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil return orgs[0].UUID, nil
} }
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织 // 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for _, org := range orgs { for _, org := range orgs {
if org.RavenType != nil && *org.RavenType == "team" { if org.RavenType != nil && *org.RavenType == "team" {
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s", logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
org.UUID, org.Name, *org.RavenType) org.UUID, org.Name, *org.RavenType)
return org.UUID, nil return org.UUID, nil
} }
} }
// 如果没有 team 类型的组织,使用第一个 // 如果没有 team 类型的组织,使用第一个
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil return orgs[0].UUID, nil
} }
...@@ -103,9 +103,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -103,9 +103,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256", "code_challenge_method": "S256",
} }
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct { var result struct {
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
...@@ -128,11 +128,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -128,11 +128,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
Post(authURL) Post(authURL)
if err != nil { if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -160,7 +160,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -160,7 +160,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState fullCode = authCode + "#" + responseState
} }
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code") logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil return fullCode, nil
} }
...@@ -192,9 +192,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -192,9 +192,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
} }
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse var tokenResp oauth.TokenResponse
...@@ -208,17 +208,17 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -208,17 +208,17 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
Post(s.tokenURL) Post(s.tokenURL)
if err != nil { if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
} }
log.Printf("[OAuth] Step 3 SUCCESS - Got access token") logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil return &tokenResp, nil
} }
......
...@@ -4,11 +4,11 @@ import ( ...@@ -4,11 +4,11 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"log"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
...@@ -72,7 +72,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -72,7 +72,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
groupIn.CreatedAt = created.CreatedAt groupIn.CreatedAt = created.CreatedAt
groupIn.UpdatedAt = created.UpdatedAt groupIn.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err) logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
} }
} }
return translatePersistenceError(err, nil, service.ErrGroupExists) return translatePersistenceError(err, nil, service.ErrGroupExists)
...@@ -152,7 +152,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -152,7 +152,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
} }
groupIn.UpdatedAt = updated.UpdatedAt groupIn.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err) logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
} }
return nil return nil
} }
...@@ -163,7 +163,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { ...@@ -163,7 +163,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return translatePersistenceError(err, service.ErrGroupNotFound, nil) return translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err) logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -296,7 +296,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou ...@@ -296,7 +296,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
} }
affected, _ := res.RowsAffected() affected, _ := res.RowsAffected()
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err) logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
} }
return affected, nil return affected, nil
} }
...@@ -406,7 +406,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -406,7 +406,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
} }
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err) logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
} }
return affectedUserIDs, nil return affectedUserIDs, nil
...@@ -500,7 +500,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64 ...@@ -500,7 +500,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
// 发送调度器事件 // 发送调度器事件
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
} }
return nil return nil
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"time" "time"
...@@ -938,6 +939,243 @@ WHERE id = $1` ...@@ -938,6 +939,243 @@ WHERE id = $1`
return err return err
} }
func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if len(inputs) == 0 {
return 0, nil
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
stmt, err := tx.PrepareContext(ctx, pq.CopyIn(
"ops_system_logs",
"created_at",
"level",
"component",
"message",
"request_id",
"client_request_id",
"user_id",
"account_id",
"platform",
"model",
"extra",
))
if err != nil {
_ = tx.Rollback()
return 0, err
}
var inserted int64
for _, input := range inputs {
if input == nil {
continue
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
component := strings.TrimSpace(input.Component)
level := strings.ToLower(strings.TrimSpace(input.Level))
message := strings.TrimSpace(input.Message)
if level == "" || message == "" {
continue
}
if component == "" {
component = "app"
}
extra := strings.TrimSpace(input.ExtraJSON)
if extra == "" {
extra = "{}"
}
if _, err := stmt.ExecContext(
ctx,
createdAt.UTC(),
level,
component,
message,
opsNullString(input.RequestID),
opsNullString(input.ClientRequestID),
opsNullInt64(input.UserID),
opsNullInt64(input.AccountID),
opsNullString(input.Platform),
opsNullString(input.Model),
extra,
); err != nil {
_ = stmt.Close()
_ = tx.Rollback()
return inserted, err
}
inserted++
}
if _, err := stmt.ExecContext(ctx); err != nil {
_ = stmt.Close()
_ = tx.Rollback()
return inserted, err
}
if err := stmt.Close(); err != nil {
_ = tx.Rollback()
return inserted, err
}
if err := tx.Commit(); err != nil {
return inserted, err
}
return inserted, nil
}
func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsSystemLogFilter{}
}
page := filter.Page
if page <= 0 {
page = 1
}
pageSize := filter.PageSize
if pageSize <= 0 {
pageSize = 50
}
if pageSize > 200 {
pageSize = 200
}
where, args, _ := buildOpsSystemLogsWhere(filter)
countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where
var total int
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
return nil, err
}
offset := (page - 1) * pageSize
argsWithLimit := append(args, pageSize, offset)
query := `
SELECT
l.id,
l.created_at,
l.level,
COALESCE(l.component, ''),
COALESCE(l.message, ''),
COALESCE(l.request_id, ''),
COALESCE(l.client_request_id, ''),
l.user_id,
l.account_id,
COALESCE(l.platform, ''),
COALESCE(l.model, ''),
COALESCE(l.extra::text, '{}')
FROM ops_system_logs l
` + where + `
ORDER BY l.created_at DESC, l.id DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
rows, err := r.db.QueryContext(ctx, query, argsWithLimit...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
logs := make([]*service.OpsSystemLog, 0, pageSize)
for rows.Next() {
item := &service.OpsSystemLog{}
var userID sql.NullInt64
var accountID sql.NullInt64
var extraRaw string
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&item.Level,
&item.Component,
&item.Message,
&item.RequestID,
&item.ClientRequestID,
&userID,
&accountID,
&item.Platform,
&item.Model,
&extraRaw,
); err != nil {
return nil, err
}
if userID.Valid {
v := userID.Int64
item.UserID = &v
}
if accountID.Valid {
v := accountID.Int64
item.AccountID = &v
}
extraRaw = strings.TrimSpace(extraRaw)
if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" {
extra := make(map[string]any)
if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil {
item.Extra = extra
}
}
logs = append(logs, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return &service.OpsSystemLogList{
Logs: logs,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsSystemLogCleanupFilter{}
}
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
if !hasConstraint {
return 0, fmt.Errorf("cleanup requires at least one filter condition")
}
query := "DELETE FROM ops_system_logs l " + where
res, err := r.db.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if input == nil {
return fmt.Errorf("nil input")
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
_, err := r.db.ExecContext(ctx, `
INSERT INTO ops_system_log_cleanup_audits (
created_at,
operator_id,
conditions,
deleted_rows
) VALUES ($1,$2,$3,$4)
`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows)
return err
}
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses := make([]string, 0, 12) clauses := make([]string, 0, 12)
args := make([]any, 0, 12) args := make([]any, 0, 12)
...@@ -1053,6 +1291,95 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { ...@@ -1053,6 +1291,95 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
return "WHERE " + strings.Join(clauses, " AND "), args return "WHERE " + strings.Join(clauses, " AND "), args
} }
func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) {
clauses := make([]string, 0, 10)
args := make([]any, 0, 10)
clauses = append(clauses, "1=1")
hasConstraint := false
if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, filter.StartTime.UTC())
clauses = append(clauses, "l.created_at >= $"+itoa(len(args)))
hasConstraint = true
}
if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, filter.EndTime.UTC())
clauses = append(clauses, "l.created_at < $"+itoa(len(args)))
hasConstraint = true
}
if filter != nil {
if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" {
args = append(args, v)
clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Component); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.RequestID); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.ClientRequestID); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args)))
hasConstraint = true
}
if filter.UserID != nil && *filter.UserID > 0 {
args = append(args, *filter.UserID)
clauses = append(clauses, "l.user_id = $"+itoa(len(args)))
hasConstraint = true
}
if filter.AccountID != nil && *filter.AccountID > 0 {
args = append(args, *filter.AccountID)
clauses = append(clauses, "l.account_id = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Platform); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Model); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Query); v != "" {
like := "%" + v + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")")
hasConstraint = true
}
}
return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint
}
func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) {
if filter == nil {
filter = &service.OpsSystemLogCleanupFilter{}
}
listFilter := &service.OpsSystemLogFilter{
StartTime: filter.StartTime,
EndTime: filter.EndTime,
Level: filter.Level,
Component: filter.Component,
RequestID: filter.RequestID,
ClientRequestID: filter.ClientRequestID,
UserID: filter.UserID,
AccountID: filter.AccountID,
Platform: filter.Platform,
Model: filter.Model,
Query: filter.Query,
}
return buildOpsSystemLogsWhere(listFilter)
}
// Helpers for nullable args // Helpers for nullable args
func opsNullString(v any) any { func opsNullString(v any) any {
switch s := v.(type) { switch s := v.(type) {
......
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) {
start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC)
userID := int64(12)
accountID := int64(34)
filter := &service.OpsSystemLogFilter{
StartTime: &start,
EndTime: &end,
Level: "warn",
Component: "http.access",
RequestID: "req-1",
ClientRequestID: "creq-1",
UserID: &userID,
AccountID: &accountID,
Platform: "openai",
Model: "gpt-5",
Query: "timeout",
}
where, args, hasConstraint := buildOpsSystemLogsWhere(filter)
if !hasConstraint {
t.Fatalf("expected hasConstraint=true")
}
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 11 {
t.Fatalf("args len = %d, want 11", len(args))
}
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
t.Fatalf("where should include client_request_id condition: %s", where)
}
if !contains(where, "l.user_id = $") {
t.Fatalf("where should include user_id condition: %s", where)
}
}
func TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) {
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{})
if hasConstraint {
t.Fatalf("expected hasConstraint=false")
}
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 0 {
t.Fatalf("args len = %d, want 0", len(args))
}
}
func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) {
userID := int64(9)
filter := &service.OpsSystemLogCleanupFilter{
ClientRequestID: "creq-9",
UserID: &userID,
}
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
if !hasConstraint {
t.Fatalf("expected hasConstraint=true")
}
if len(args) != 2 {
t.Fatalf("args len = %d, want 2", len(args))
}
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
t.Fatalf("where should include client_request_id condition: %s", where)
}
if !contains(where, "l.user_id = $") {
t.Fatalf("where should include user_id condition: %s", where)
}
}
func contains(s string, sub string) bool {
return strings.Contains(s, sub)
}
...@@ -2,10 +2,13 @@ package middleware ...@@ -2,10 +2,13 @@ package middleware
import ( import (
"context" "context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap"
) )
// ClientRequestID ensures every request has a unique client_request_id in request.Context(). // ClientRequestID ensures every request has a unique client_request_id in request.Context().
...@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc { ...@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc {
} }
id := uuid.New().String() id := uuid.New().String()
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)) ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
ctx = logger.IntoContext(ctx, requestLogger)
c.Request = c.Request.WithContext(ctx)
c.Next() c.Next()
} }
} }
package middleware package middleware
import ( import (
"log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap"
) )
// Logger 请求日志中间件 // Logger 请求日志中间件
...@@ -24,38 +26,71 @@ func Logger() gin.HandlerFunc { ...@@ -24,38 +26,71 @@ func Logger() gin.HandlerFunc {
return return
} }
// 结束时间
endTime := time.Now() endTime := time.Now()
// 执行时间
latency := endTime.Sub(startTime) latency := endTime.Sub(startTime)
// 请求方法
method := c.Request.Method method := c.Request.Method
// 状态码
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
// 客户端IP
clientIP := c.ClientIP() clientIP := c.ClientIP()
// 协议版本
protocol := c.Request.Proto protocol := c.Request.Proto
accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64)
platform, _ := c.Request.Context().Value(ctxkey.Platform).(string)
model, _ := c.Request.Context().Value(ctxkey.Model).(string)
fields := []zap.Field{
zap.String("component", "http.access"),
zap.Int("status_code", statusCode),
zap.Int64("latency_ms", latency.Milliseconds()),
zap.String("client_ip", clientIP),
zap.String("protocol", protocol),
zap.String("method", method),
zap.String("path", path),
}
if hasAccountID && accountID > 0 {
fields = append(fields, zap.Int64("account_id", accountID))
}
if platform != "" {
fields = append(fields, zap.String("platform", platform))
}
if model != "" {
fields = append(fields, zap.String("model", model))
}
l := logger.FromContext(c.Request.Context()).With(fields...)
l.Info("http request completed", zap.Time("completed_at", endTime))
// 当全局日志级别高于 info(如 warn/error)时,access info 不会进入 zap core,
// 这里补写一次 sink,保证 ops 系统日志仍可索引关键访问轨迹。
if !logger.L().Core().Enabled(logger.LevelInfo) {
sinkFields := map[string]any{
"component": "http.access",
"status_code": statusCode,
"latency_ms": latency.Milliseconds(),
"client_ip": clientIP,
"protocol": protocol,
"method": method,
"path": path,
"completed_at": endTime,
}
if requestID, ok := c.Request.Context().Value(ctxkey.RequestID).(string); ok && requestID != "" {
sinkFields["request_id"] = requestID
}
if clientRequestID, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string); ok && clientRequestID != "" {
sinkFields["client_request_id"] = clientRequestID
}
if hasAccountID && accountID > 0 {
sinkFields["account_id"] = accountID
}
if platform != "" {
sinkFields["platform"] = platform
}
if model != "" {
sinkFields["model"] = model
}
logger.WriteSinkEvent("info", "http.access", "http request completed", sinkFields)
}
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
protocol,
method,
path,
)
// 如果有错误,额外记录错误信息
if len(c.Errors) > 0 { if len(c.Errors) > 0 {
log.Printf("[GIN] Errors: %v", c.Errors.String()) l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String()))
} }
} }
} }
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
type testLogSink struct {
mu sync.Mutex
events []*logger.LogEvent
}
func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) {
s.mu.Lock()
defer s.mu.Unlock()
s.events = append(s.events, event)
}
func (s *testLogSink) list() []*logger.LogEvent {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*logger.LogEvent, len(s.events))
copy(out, s.events)
return out
}
func initMiddlewareTestLogger(t *testing.T) *testLogSink {
return initMiddlewareTestLoggerWithLevel(t, "debug")
}
func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink {
t.Helper()
level = strings.TrimSpace(level)
if level == "" {
level = "debug"
}
if err := logger.Init(logger.InitOptions{
Level: level,
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
sink := &testLogSink{}
logger.SetSink(sink)
t.Cleanup(func() {
logger.SetSink(nil)
})
return sink
}
func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(RequestLogger())
r.GET("/t", func(c *gin.Context) {
reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string)
if !ok || reqID == "" {
t.Fatalf("request_id missing in context")
}
if got := c.Writer.Header().Get(requestIDHeader); got != reqID {
t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID)
}
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if w.Header().Get(requestIDHeader) == "" {
t.Fatalf("X-Request-ID should be set")
}
}
func TestRequestLogger_KeepIncomingRequestID(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(RequestLogger())
r.GET("/t", func(c *gin.Context) {
reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
if reqID != "rid-fixed" {
t.Fatalf("request_id=%q, want rid-fixed", reqID)
}
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set(requestIDHeader, "rid-fixed")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if got := w.Header().Get(requestIDHeader); got != "rid-fixed" {
t.Fatalf("header=%q, want rid-fixed", got)
}
}
func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLogger(t)
r := gin.New()
r.Use(Logger())
r.Use(func(c *gin.Context) {
ctx := c.Request.Context()
ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101))
ctx = context.WithValue(ctx, ctxkey.Platform, "openai")
ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5")
c.Request = c.Request.WithContext(ctx)
c.Next()
})
r.GET("/api/test", func(c *gin.Context) {
c.Status(http.StatusCreated)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status=%d", w.Code)
}
events := sink.list()
if len(events) == 0 {
t.Fatalf("expected at least one log event")
}
found := false
for _, event := range events {
if event == nil || event.Message != "http request completed" {
continue
}
found = true
switch v := event.Fields["status_code"].(type) {
case int:
if v != http.StatusCreated {
t.Fatalf("status_code field mismatch: %v", v)
}
case int64:
if v != int64(http.StatusCreated) {
t.Fatalf("status_code field mismatch: %v", v)
}
default:
t.Fatalf("status_code type mismatch: %T", v)
}
switch v := event.Fields["account_id"].(type) {
case int64:
if v != 101 {
t.Fatalf("account_id field mismatch: %v", v)
}
case int:
if v != 101 {
t.Fatalf("account_id field mismatch: %v", v)
}
default:
t.Fatalf("account_id type mismatch: %T", v)
}
if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" {
t.Fatalf("platform/model mismatch: %+v", event.Fields)
}
}
if !found {
t.Fatalf("access log event not found")
}
}
func TestLogger_HealthPathSkipped(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLogger(t)
r := gin.New()
r.Use(Logger())
r.GET("/health", func(c *gin.Context) {
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if len(sink.list()) != 0 {
t.Fatalf("health endpoint should not write access log")
}
}
func TestLogger_AccessLogStillIndexedWhenLevelWarn(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLoggerWithLevel(t, "warn")
r := gin.New()
r.Use(RequestLogger())
r.Use(Logger())
r.GET("/api/test", func(c *gin.Context) {
c.Status(http.StatusCreated)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status=%d", w.Code)
}
events := sink.list()
if len(events) == 0 {
t.Fatalf("expected access log event to be indexed when level=warn")
}
found := false
for _, event := range events {
if event == nil || event.Message != "http request completed" {
continue
}
found = true
if event.Level != "info" {
t.Fatalf("event level=%q, want info", event.Level)
}
if event.Component != "http.access" && event.Fields["component"] != "http.access" {
t.Fatalf("event component mismatch: component=%q fields=%v", event.Component, event.Fields["component"])
}
if _, ok := event.Fields["status_code"]; !ok {
t.Fatalf("status_code field missing: %+v", event.Fields)
}
if _, ok := event.Fields["request_id"]; !ok {
t.Fatalf("request_id field missing: %+v", event.Fields)
}
}
if !found {
t.Fatalf("access log event not found")
}
}
package middleware
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
const requestIDHeader = "X-Request-ID"
// RequestLogger 在请求入口注入 request-scoped logger。
func RequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request == nil {
c.Next()
return
}
requestID := strings.TrimSpace(c.GetHeader(requestIDHeader))
if requestID == "" {
requestID = uuid.NewString()
}
c.Header(requestIDHeader, requestID)
ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID)
clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string)
requestLogger := logger.With(
zap.String("component", "http"),
zap.String("request_id", requestID),
zap.String("client_request_id", strings.TrimSpace(clientRequestID)),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
)
ctx = logger.IntoContext(ctx, requestLogger)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
...@@ -29,6 +29,7 @@ func SetupRouter( ...@@ -29,6 +29,7 @@ func SetupRouter(
redisClient *redis.Client, redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware2.RequestLogger())
r.Use(middleware2.Logger()) r.Use(middleware2.Logger())
r.Use(middleware2.CORS(cfg.CORS)) r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
......
...@@ -101,6 +101,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -101,6 +101,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings) runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings)
runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings) runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings)
runtime.GET("/logging", h.Admin.Ops.GetRuntimeLogConfig)
runtime.PUT("/logging", h.Admin.Ops.UpdateRuntimeLogConfig)
runtime.POST("/logging/reset", h.Admin.Ops.ResetRuntimeLogConfig)
} }
// Advanced settings (DB-backed) // Advanced settings (DB-backed)
...@@ -144,6 +147,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -144,6 +147,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Request drilldown (success + error) // Request drilldown (success + error)
ops.GET("/requests", h.Admin.Ops.ListRequestDetails) ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
// Indexed system logs
ops.GET("/system-logs", h.Admin.Ops.ListSystemLogs)
ops.POST("/system-logs/cleanup", h.Admin.Ops.CleanupSystemLogs)
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
// Dashboard (vNext - raw path for MVP) // Dashboard (vNext - raw path for MVP)
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview) ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend) ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment