"backend/internal/handler/vscode:/vscode.git/clone" did not exist on "1245f07a2d4a0fe79c89364a440bcfc1d2af2c27"
Commit 64236361 authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'test' into dev

parents 2d6066f9 b6aaee01
......@@ -8,7 +8,6 @@ import (
"errors"
"flag"
"log"
"log/slog"
"net/http"
"os"
"os/signal"
......@@ -19,6 +18,7 @@ import (
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web"
......@@ -49,22 +49,9 @@ func init() {
// initLogger configures the default slog handler based on gin.Mode().
// In non-release mode, Debug level logs are enabled.
func initLogger() {
var level slog.Level
if gin.Mode() == gin.ReleaseMode {
level = slog.LevelInfo
} else {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
slog.SetDefault(slog.New(handler))
}
func main() {
// Initialize slog logger based on gin mode
initLogger()
logger.InitBootstrap()
defer logger.Sync()
// Parse command line flags
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
......@@ -141,6 +128,9 @@ func runMainServer() {
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
if err := logger.Init(logger.OptionsFromConfig(cfg.Log)); err != nil {
log.Fatalf("Failed to initialize logger: %v", err)
}
if cfg.RunMode == config.RunModeSimple {
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
}
......
......@@ -67,6 +67,7 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
......@@ -103,6 +104,12 @@ func provideCleanup(
}
return nil
}},
{"OpsSystemLogSink", func() error {
if opsSystemLogSink != nil {
opsSystemLogSink.Stop()
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
......
......@@ -160,7 +160,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
......@@ -204,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
......@@ -234,6 +235,7 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
......@@ -269,6 +271,12 @@ func provideCleanup(
}
return nil
}},
{"OpsSystemLogSink", func() error {
if opsSystemLogSink != nil {
opsSystemLogSink.Stop()
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
......
......@@ -5,6 +5,7 @@ go 1.25.7
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
......@@ -13,6 +14,7 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0
github.com/lib/pq v1.10.9
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pquerna/otp v1.5.0
github.com/redis/go-redis/v9 v9.17.2
github.com/refraction-networking/utls v1.8.1
......@@ -25,10 +27,12 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.47.0
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.39.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
)
......@@ -45,7 +49,6 @@ require (
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
......@@ -104,7 +107,6 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
......
......@@ -18,6 +18,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
......@@ -137,8 +139,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
......@@ -174,8 +174,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
......@@ -209,8 +207,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
......@@ -240,8 +236,6 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
......@@ -264,8 +258,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
......@@ -342,10 +334,14 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
......@@ -393,6 +389,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
......
......@@ -5,7 +5,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"log/slog"
"net/url"
"os"
"strings"
......@@ -39,6 +39,7 @@ const (
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
......@@ -68,6 +69,38 @@ type Config struct {
Update UpdateConfig `mapstructure:"update"`
}
type LogConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
ServiceName string `mapstructure:"service_name"`
Environment string `mapstructure:"env"`
Caller bool `mapstructure:"caller"`
StacktraceLevel string `mapstructure:"stacktrace_level"`
Output LogOutputConfig `mapstructure:"output"`
Rotation LogRotationConfig `mapstructure:"rotation"`
Sampling LogSamplingConfig `mapstructure:"sampling"`
}
type LogOutputConfig struct {
ToStdout bool `mapstructure:"to_stdout"`
ToFile bool `mapstructure:"to_file"`
FilePath string `mapstructure:"file_path"`
}
type LogRotationConfig struct {
MaxSizeMB int `mapstructure:"max_size_mb"`
MaxBackups int `mapstructure:"max_backups"`
MaxAgeDays int `mapstructure:"max_age_days"`
Compress bool `mapstructure:"compress"`
LocalTime bool `mapstructure:"local_time"`
}
type LogSamplingConfig struct {
Enabled bool `mapstructure:"enabled"`
Initial int `mapstructure:"initial"`
Thereafter int `mapstructure:"thereafter"`
}
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
......@@ -756,6 +789,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
......@@ -766,7 +805,7 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
......@@ -786,19 +825,19 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
}
if !cfg.Security.URLAllowlist.Enabled {
log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
}
if !cfg.Security.ResponseHeaders.Enabled {
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
}
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
}
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
cfg.Security.ResponseHeaders.AdditionalAllowed,
cfg.Security.ResponseHeaders.ForceRemove,
slog.Info("response header policy configured",
"additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed,
"force_remove", cfg.Security.ResponseHeaders.ForceRemove,
)
}
......@@ -825,6 +864,25 @@ func setDefaults() {
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
// Log
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "console")
viper.SetDefault("log.service_name", "sub2api")
viper.SetDefault("log.env", "production")
viper.SetDefault("log.caller", true)
viper.SetDefault("log.stacktrace_level", "error")
viper.SetDefault("log.output.to_stdout", true)
viper.SetDefault("log.output.to_file", true)
viper.SetDefault("log.output.file_path", "")
viper.SetDefault("log.rotation.max_size_mb", 100)
viper.SetDefault("log.rotation.max_backups", 10)
viper.SetDefault("log.rotation.max_age_days", 7)
viper.SetDefault("log.rotation.compress", true)
viper.SetDefault("log.rotation.local_time", true)
viper.SetDefault("log.sampling.enabled", false)
viper.SetDefault("log.sampling.initial", 100)
viper.SetDefault("log.sampling.thereafter", 100)
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true)
......@@ -1098,6 +1156,54 @@ func (c *Config) Validate() error {
if len([]byte(jwtSecret)) < 32 {
return fmt.Errorf("jwt.secret must be at least 32 bytes")
}
switch c.Log.Level {
case "debug", "info", "warn", "error":
case "":
return fmt.Errorf("log.level is required")
default:
return fmt.Errorf("log.level must be one of: debug/info/warn/error")
}
switch c.Log.Format {
case "json", "console":
case "":
return fmt.Errorf("log.format is required")
default:
return fmt.Errorf("log.format must be one of: json/console")
}
switch c.Log.StacktraceLevel {
case "none", "error", "fatal":
case "":
return fmt.Errorf("log.stacktrace_level is required")
default:
return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal")
}
if !c.Log.Output.ToStdout && !c.Log.Output.ToFile {
return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false")
}
if c.Log.Rotation.MaxSizeMB <= 0 {
return fmt.Errorf("log.rotation.max_size_mb must be positive")
}
if c.Log.Rotation.MaxBackups < 0 {
return fmt.Errorf("log.rotation.max_backups must be non-negative")
}
if c.Log.Rotation.MaxAgeDays < 0 {
return fmt.Errorf("log.rotation.max_age_days must be non-negative")
}
if c.Log.Sampling.Enabled {
if c.Log.Sampling.Initial <= 0 {
return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled")
}
if c.Log.Sampling.Thereafter <= 0 {
return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled")
}
} else {
if c.Log.Sampling.Initial < 0 {
return fmt.Errorf("log.sampling.initial must be non-negative")
}
if c.Log.Sampling.Thereafter < 0 {
return fmt.Errorf("log.sampling.thereafter must be non-negative")
}
}
if c.SubscriptionMaintenance.WorkerCount < 0 {
return fmt.Errorf("subscription_maintenance.worker_count must be non-negative")
......@@ -1137,20 +1243,20 @@ func (c *Config) Validate() error {
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
}
if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour)
}
// JWT Refresh Token配置验证
if c.JWT.AccessTokenExpireMinutes <= 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
}
if c.JWT.AccessTokenExpireMinutes > 720 {
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes)
}
if c.JWT.RefreshTokenExpireDays <= 0 {
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
}
if c.JWT.RefreshTokenExpireDays > 90 {
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays)
}
if c.JWT.RefreshWindowMinutes < 0 {
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
......@@ -1445,7 +1551,7 @@ func (c *Config) Validate() error {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.IdleConnTimeoutSeconds > 180 {
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds)
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
......@@ -1682,6 +1788,6 @@ func warnIfInsecureURL(field, raw string) {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field)
}
}
......@@ -965,6 +965,37 @@ func TestValidateConfigErrors(t *testing.T) {
},
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
},
{
name: "log level invalid",
mutate: func(c *Config) { c.Log.Level = "trace" },
wantErr: "log.level",
},
{
name: "log format invalid",
mutate: func(c *Config) { c.Log.Format = "plain" },
wantErr: "log.format",
},
{
name: "log output disabled",
mutate: func(c *Config) {
c.Log.Output.ToStdout = false
c.Log.Output.ToFile = false
},
wantErr: "log.output.to_stdout and log.output.to_file cannot both be false",
},
{
name: "log rotation size",
mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 },
wantErr: "log.rotation.max_size_mb",
},
{
name: "log sampling enabled invalid",
mutate: func(c *Config) {
c.Log.Sampling.Enabled = true
c.Log.Sampling.Initial = 0
},
wantErr: "log.sampling.initial",
},
{
name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
......
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type testSettingRepo struct {
values map[string]string
}
func newTestSettingRepo() *testSettingRepo {
return &testSettingRepo{values: map[string]string{}}
}
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
v, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &service.Setting{Key: key, Value: v}, nil
}
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
v, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
s.values[key] = value
return nil
}
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, k := range keys {
if v, ok := s.values[k]; ok {
out[k] = v
}
}
return out, nil
}
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
s.values[k] = v
}
return nil
}
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for k, v := range s.values {
out[k] = v
}
return out, nil
}
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
delete(s.values, key)
return nil
}
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
c.Next()
})
}
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
return r
}
func newRuntimeOpsService(t *testing.T) *service.OpsService {
t.Helper()
if err := logger.Init(logger.InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
settingRepo := newTestSettingRepo()
cfg := &config.Config{
Ops: config.OpsConfig{Enabled: true},
Log: config.LogConfig{
Level: "info",
Caller: true,
StacktraceLevel: "error",
Sampling: config.LogSamplingConfig{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
},
}
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{
"level": "debug",
"enable_sampling": false,
"sampling_initial": 100,
"sampling_thereafter": 100,
"caller": true,
"stacktrace_level": "error",
"retention_days": 30,
}
raw, _ := json.Marshal(payload)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
}
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
}
}
......@@ -4,6 +4,7 @@ import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
......@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
response.Success(c, updated)
}
// GetRuntimeLogConfig returns runtime log config (DB-backed).
// GET /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
return
}
response.Success(c, cfg)
}
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
// PUT /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsRuntimeLogConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
// POST /api/v1/admin/ops/runtime/logging/reset
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
......
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type opsSystemLogCleanupRequest struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Level string `json:"level"`
Component string `json:"component"`
RequestID string `json:"request_id"`
ClientRequestID string `json:"client_request_id"`
UserID *int64 `json:"user_id"`
AccountID *int64 `json:"account_id"`
Platform string `json:"platform"`
Model string `json:"model"`
Query string `json:"q"`
}
// ListSystemLogs returns indexed system logs.
// GET /api/v1/admin/ops/system-logs
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 200 {
pageSize = 200
}
start, end, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsSystemLogFilter{
Page: page,
PageSize: pageSize,
StartTime: &start,
EndTime: &end,
Level: strings.TrimSpace(c.Query("level")),
Component: strings.TrimSpace(c.Query("component")),
RequestID: strings.TrimSpace(c.Query("request_id")),
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
Platform: strings.TrimSpace(c.Query("platform")),
Model: strings.TrimSpace(c.Query("model")),
Query: strings.TrimSpace(c.Query("q")),
}
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
}
// CleanupSystemLogs deletes indexed system logs by filter.
// POST /api/v1/admin/ops/system-logs/cleanup
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
var req opsSystemLogCleanupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
parseTS := func(raw string) (*time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
return &t, nil
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return nil, err
}
return &t, nil
}
start, err := parseTS(req.StartTime)
if err != nil {
response.BadRequest(c, "Invalid start_time")
return
}
end, err := parseTS(req.EndTime)
if err != nil {
response.BadRequest(c, "Invalid end_time")
return
}
filter := &service.OpsSystemLogCleanupFilter{
StartTime: start,
EndTime: end,
Level: strings.TrimSpace(req.Level),
Component: strings.TrimSpace(req.Component),
RequestID: strings.TrimSpace(req.RequestID),
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
UserID: req.UserID,
AccountID: req.AccountID,
Platform: strings.TrimSpace(req.Platform),
Model: strings.TrimSpace(req.Model),
Query: strings.TrimSpace(req.Query),
}
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": deleted})
}
// GetSystemLogIngestionHealth returns sink health metrics.
// GET /api/v1/admin/ops/system-logs/health
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.opsService.GetSystemLogSinkHealth())
}
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type responseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
c.Next()
})
}
r.GET("/logs", handler.ListSystemLogs)
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
return r
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
var resp responseEnvelope
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Code != 0 {
t.Fatalf("unexpected response code: %+v", resp)
}
}
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
......@@ -3,7 +3,6 @@ package admin
import (
"context"
"encoding/json"
"log"
"math"
"net"
"net/http"
......@@ -16,6 +15,7 @@ import (
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
......@@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil {
if err != nil {
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
}
return
}
......@@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
msg, err := json.Marshal(payload)
if err != nil {
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
return
}
......@@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
// Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
......@@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
......@@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
return
}
......@@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
......@@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Printf("[OpsWS] read failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
}
return
}
......@@ -508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
......@@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
......@@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed
} else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
}
}
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
}
cfg.TrustedProxies = prefixes
}
......@@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized
default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
}
}
......@@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
}
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
}
}
return cfg
......
package admin
import (
"log"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
......@@ -378,11 +378,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
......@@ -390,7 +390,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
......@@ -472,7 +472,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
......@@ -488,12 +488,12 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
}
......@@ -515,12 +515,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}
......@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
......@@ -19,11 +18,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// GatewayHandler handles API gateway requests
......@@ -98,6 +99,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.messages",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
......@@ -124,6 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
......@@ -163,9 +172,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err))
// On error, allow request to proceed
} else if !canWait {
reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
......@@ -182,7 +192,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
......@@ -199,7 +209,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
......@@ -251,7 +261,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
log.Printf("[Gateway] SelectAccount failed: %v", err)
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
......@@ -260,7 +270,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
......@@ -276,7 +289,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() {
......@@ -304,9 +317,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
......@@ -329,7 +345,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
......@@ -337,7 +353,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
......@@ -370,7 +386,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
......@@ -379,7 +400,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
continue
}
// 错误响应已在Forward中处理,这里只记录日志
log.Printf("Forward request failed: %v", err)
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
......@@ -402,7 +423,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
return
......@@ -437,7 +465,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
log.Printf("[Gateway] SelectAccount failed: %v", err)
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
......@@ -446,7 +474,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
......@@ -462,7 +493,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() {
......@@ -490,9 +521,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
......@@ -515,7 +549,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
......@@ -523,7 +557,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
......@@ -546,18 +580,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) {
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
zap.Any("current_group_id", currentAPIKey.GroupID),
zap.Any("fallback_group_id", fallbackGroupID),
zap.Bool("fallback_used", fallbackUsed),
)
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
if err != nil {
log.Printf("Resolve fallback group failed: %v", err)
reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err))
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
if fallbackGroup.Platform != service.PlatformAnthropic ||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
reqLog.Warn("gateway.fallback_group_invalid",
zap.Int64("fallback_group_id", fallbackGroup.ID),
zap.String("fallback_platform", fallbackGroup.Platform),
zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType),
)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
......@@ -591,7 +633,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
......@@ -600,7 +647,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
continue
}
// 错误响应已在Forward中处理,这里只记录日志
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
......@@ -623,9 +670,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", currentAPIKey.ID),
zap.Any("group_id", currentAPIKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
reqLog.Debug("gateway.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
zap.Bool("fallback_used", fallbackUsed),
)
return
}
if !retryWithFallback {
......@@ -902,7 +961,11 @@ func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) b
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const delay = 2 * time.Second
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
logger.L().With(
zap.String("component", "handler.gateway.failover"),
zap.Duration("delay", delay),
zap.Int("retry_count", retryCount),
).Info("gateway.single_account_backoff_waiting")
select {
case <-ctx.Done():
......@@ -1023,6 +1086,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.count_tokens",
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
......@@ -1050,6 +1119,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
......@@ -1083,15 +1153,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
// 错误响应已在 ForwardCountTokens 中处理
return
}
......@@ -1355,7 +1425,10 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
msg := pkgerrors.Message(err)
if msg == "" {
log.Printf("[Gateway] billing error details: %v", err)
logger.L().With(
zap.String("component", "handler.gateway.billing"),
zap.Error(err),
).Warn("gateway.billing_error_missing_message")
msg = "Billing error"
}
return http.StatusForbidden, "billing_error", msg
......
......@@ -8,7 +8,6 @@ import (
"encoding/json"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
......@@ -20,11 +19,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
......@@ -143,6 +144,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gemini_v1beta.models",
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
if !middleware.HasForcePlatform(c) {
......@@ -159,6 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
stream := action == "streamGenerateContent"
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
body, err := io.ReadAll(c.Request.Body)
if err != nil {
......@@ -187,8 +196,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
......@@ -208,6 +218,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
......@@ -223,6 +234,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
......@@ -296,8 +308,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
reqLog.Info("gemini.digest_fallback_matched",
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
zap.Int64("account_id", foundAccountID),
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
)
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
......@@ -346,7 +361,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
reqLog.Warn("gemini.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
......@@ -358,18 +376,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
reqLog.Info("gemini.sticky_session_account_switched",
zap.Int64("from_account_id", sessionBoundAccountID),
zap.Int64("to_account_id", account.ID),
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
reqLog.Info("gemini.sticky_session_binding_missing",
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
......@@ -388,9 +412,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted := false
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gemini.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
......@@ -412,6 +439,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
&streamStarted,
)
if err != nil {
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
......@@ -420,7 +448,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
......@@ -454,7 +482,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
reqLog.Warn("gemini.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
......@@ -463,7 +496,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
continue
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
......@@ -482,7 +515,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account.ID,
matchedDigestChain,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
......@@ -504,9 +537,20 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", modelName),
zap.Int64("account_id", usedAccount.ID),
).Error("gemini.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}
......
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger {
base := logger.L()
if c != nil && c.Request != nil {
base = logger.FromContext(c.Request.Context())
}
if component != "" {
fields = append([]zap.Field{zap.String("component", component)}, fields...)
}
return base.With(fields...)
}
......@@ -6,18 +6,19 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
......@@ -74,6 +75,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := io.ReadAll(c.Request.Body)
......@@ -113,6 +121,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
......@@ -128,13 +137,17 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
......@@ -160,7 +173,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 0. 先尝试直接抢占用户槽位(快速路径)
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
......@@ -171,9 +184,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if waitErr != nil {
log.Printf("Increment wait count failed: %v", waitErr)
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait {
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
......@@ -188,7 +202,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
......@@ -207,7 +221,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
......@@ -223,10 +237,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: groupID=%v model=%s tried=%d err=%v", apiKey.GroupID, reqModel, len(failedAccountIDs), err)
reqLog.Warn("openai.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
......@@ -239,8 +256,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
setOpsSelectedAccount(c, account.ID)
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
// 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
......@@ -257,22 +274,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
log.Printf("Account concurrency quick acquire failed: %v", err)
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if fastAcquired {
accountReleaseFunc = fastReleaseFunc
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
} else {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("openai.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
......@@ -295,7 +315,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
......@@ -303,7 +323,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
}
......@@ -337,11 +357,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
// Error response already handled in Forward, just log
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
......@@ -363,9 +388,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: ip,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
).Error("openai.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP)
reqLog.Debug("openai.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}
......
......@@ -255,18 +255,33 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
if c == nil {
return
}
model = strings.TrimSpace(model)
c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream)
if len(requestBody) > 0 {
c.Set(opsRequestBodyKey, requestBody)
}
if c.Request != nil && model != "" {
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
c.Request = c.Request.WithContext(ctx)
}
}
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
if c == nil || accountID <= 0 {
return
}
c.Set(opsAccountIDKey, accountID)
if c.Request != nil {
ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID)
if len(platform) > 0 {
p := strings.TrimSpace(platform[0])
if p != "" {
ctx = context.WithValue(ctx, ctxkey.Platform, p)
}
}
c.Request = c.Request.WithContext(ctx)
}
}
type opsCaptureWriter struct {
......
......@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path"
......@@ -18,12 +17,14 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// SoraGatewayHandler handles Sora chat completions requests
......@@ -89,6 +90,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.sora_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
body, err := io.ReadAll(c.Request.Body)
if err != nil {
......@@ -127,6 +135,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}
clientStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
if !clientStream {
if h.streamMode == "error" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
......@@ -160,8 +169,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
......@@ -176,7 +186,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
......@@ -190,7 +200,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
......@@ -206,7 +216,10 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
if err != nil {
log.Printf("[Sora Handler] SelectAccount failed: %v", err)
reqLog.Warn("sora.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
......@@ -215,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
......@@ -226,9 +239,12 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("sora.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
......@@ -250,7 +266,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
......@@ -276,10 +292,15 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
reqLog.Warn("sora.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
......@@ -298,9 +319,20 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
UserAgent: ua,
IPAddress: ip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
).Error("sora.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP)
reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}
......
......@@ -8,9 +8,21 @@ const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
// RequestID 为服务端生成/透传的请求 ID。
RequestID Key = "ctx_request_id"
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
ClientRequestID Key = "ctx_client_request_id"
// Model 请求模型标识(用于统一请求链路日志字段)。
Model Key = "ctx_model"
// Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。
Platform Key = "ctx_platform"
// AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。
AccountID Key = "ctx_account_id"
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment