Commit 994da655 authored by 陈曦's avatar 陈曦
Browse files

收集req和resp的相关更改

parent 8f7ac1ea
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/requestcapturelog"
"github.com/Wei-Shaw/sub2api/ent/schema" "github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/setting"
...@@ -132,6 +133,10 @@ func init() { ...@@ -132,6 +133,10 @@ func init() {
apikeyDescUsage7d := apikeyFields[16].Descriptor() apikeyDescUsage7d := apikeyFields[16].Descriptor()
// apikey.DefaultUsage7d holds the default value on creation for the usage_7d field. // apikey.DefaultUsage7d holds the default value on creation for the usage_7d field.
apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64) apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64)
// apikeyDescCaptureRequests is the schema descriptor for capture_requests field.
apikeyDescCaptureRequests := apikeyFields[20].Descriptor()
// apikey.DefaultCaptureRequests holds the default value on creation for the capture_requests field.
apikey.DefaultCaptureRequests = apikeyDescCaptureRequests.Default.(bool)
accountMixin := schema.Account{}.Mixin() accountMixin := schema.Account{}.Mixin()
accountMixinHooks1 := accountMixin[1].Hooks() accountMixinHooks1 := accountMixin[1].Hooks()
account.Hooks[0] = accountMixinHooks1[0] account.Hooks[0] = accountMixinHooks1[0]
...@@ -867,6 +872,32 @@ func init() { ...@@ -867,6 +872,32 @@ func init() {
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor() redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field. // redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int) redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
requestcapturelogFields := schema.RequestCaptureLog{}.Fields()
_ = requestcapturelogFields
// requestcapturelogDescRequestID is the schema descriptor for request_id field.
requestcapturelogDescRequestID := requestcapturelogFields[2].Descriptor()
// requestcapturelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save.
requestcapturelog.RequestIDValidator = requestcapturelogDescRequestID.Validators[0].(func(string) error)
// requestcapturelogDescPath is the schema descriptor for path field.
requestcapturelogDescPath := requestcapturelogFields[3].Descriptor()
// requestcapturelog.PathValidator is a validator for the "path" field. It is called by the builders before save.
requestcapturelog.PathValidator = requestcapturelogDescPath.Validators[0].(func(string) error)
// requestcapturelogDescMethod is the schema descriptor for method field.
requestcapturelogDescMethod := requestcapturelogFields[4].Descriptor()
// requestcapturelog.MethodValidator is a validator for the "method" field. It is called by the builders before save.
requestcapturelog.MethodValidator = requestcapturelogDescMethod.Validators[0].(func(string) error)
// requestcapturelogDescIPAddress is the schema descriptor for ip_address field.
requestcapturelogDescIPAddress := requestcapturelogFields[5].Descriptor()
// requestcapturelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
requestcapturelog.IPAddressValidator = requestcapturelogDescIPAddress.Validators[0].(func(string) error)
// requestcapturelogDescNfsFilePath is the schema descriptor for nfs_file_path field.
requestcapturelogDescNfsFilePath := requestcapturelogFields[8].Descriptor()
// requestcapturelog.NfsFilePathValidator is a validator for the "nfs_file_path" field. It is called by the builders before save.
requestcapturelog.NfsFilePathValidator = requestcapturelogDescNfsFilePath.Validators[0].(func(string) error)
// requestcapturelogDescCreatedAt is the schema descriptor for created_at field.
requestcapturelogDescCreatedAt := requestcapturelogFields[9].Descriptor()
// requestcapturelog.DefaultCreatedAt holds the default value on creation for the created_at field.
requestcapturelog.DefaultCreatedAt = requestcapturelogDescCreatedAt.Default.(func() time.Time)
securitysecretMixin := schema.SecuritySecret{}.Mixin() securitysecretMixin := schema.SecuritySecret{}.Mixin()
securitysecretMixinFields0 := securitysecretMixin[0].Fields() securitysecretMixinFields0 := securitysecretMixin[0].Fields()
_ = securitysecretMixinFields0 _ = securitysecretMixinFields0
......
...@@ -115,6 +115,11 @@ func (APIKey) Fields() []ent.Field { ...@@ -115,6 +115,11 @@ func (APIKey) Fields() []ent.Field {
Optional(). Optional().
Nillable(). Nillable().
Comment("Start time of the current 7d rate limit window"), Comment("Start time of the current 7d rate limit window"),
// ========== Request capture ==========
field.Bool("capture_requests").
Default(false).
Comment("是否对该 API Key 的请求体进行存储捕获"),
} }
} }
......
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// RequestCaptureLog 记录指定 API Key 的请求体,用于审计和分析。
// 只追加,不支持更新/删除(同 PaymentAuditLog 模式)。
type RequestCaptureLog struct {
ent.Schema
}
func (RequestCaptureLog) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "request_capture_logs"},
}
}
func (RequestCaptureLog) Fields() []ent.Field {
return []ent.Field{
field.Int64("api_key_id"),
field.Int64("user_id"),
field.String("request_id").
MaxLen(64).
Optional().
Nillable(),
field.String("path").
MaxLen(100).
Optional().
Nillable(),
field.String("method").
MaxLen(10).
Optional().
Nillable(),
field.String("ip_address").
MaxLen(45).
Optional().
Nillable(),
// request_body 存原始 JSON 文本,不加索引,避免影响查询计划
field.Text("request_body").
Optional().
Nillable(),
// response_body 存响应文本(非 streaming 为完整 JSON,streaming 为拼接的 assistant text)
field.Text("response_body").
Optional().
Nillable(),
// nfs_file_path NFS 文件路径快照,方便核查
field.String("nfs_file_path").
MaxLen(500).
Optional().
Nillable(),
field.Time("created_at").
Default(time.Now).
Immutable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (RequestCaptureLog) Edges() []ent.Edge {
return nil
}
func (RequestCaptureLog) Indexes() []ent.Index {
return []ent.Index{
index.Fields("api_key_id", "created_at"),
index.Fields("user_id"),
}
}
...@@ -44,6 +44,8 @@ type Tx struct { ...@@ -44,6 +44,8 @@ type Tx struct {
Proxy *ProxyClient Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders. // RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient RedeemCode *RedeemCodeClient
// RequestCaptureLog is the client for interacting with the RequestCaptureLog builders.
RequestCaptureLog *RequestCaptureLogClient
// SecuritySecret is the client for interacting with the SecuritySecret builders. // SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders. // Setting is the client for interacting with the Setting builders.
...@@ -212,6 +214,7 @@ func (tx *Tx) init() { ...@@ -212,6 +214,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config) tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.RequestCaptureLog = NewRequestCaptureLogClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config) tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config) tx.Setting = NewSettingClient(tx.config)
tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config) tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config)
......
...@@ -83,6 +83,16 @@ type Config struct { ...@@ -83,6 +83,16 @@ type Config struct {
Gemini GeminiConfig `mapstructure:"gemini"` Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"` Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"` Idempotency IdempotencyConfig `mapstructure:"idempotency"`
RequestCapture RequestCaptureConfig `mapstructure:"request_capture"`
}
// RequestCaptureConfig 配置请求体捕获功能
type RequestCaptureConfig struct {
// NFSPath 为本地挂载的 NFS 根目录(例如 /mnt/nfs/requests)。
// 留空则跳过文件写入,只写数据库。
NFSPath string `mapstructure:"nfs_path"`
// WorkerTimeoutSeconds 单次异步写入的超时时间(秒),默认 5。
WorkerTimeoutSeconds int `mapstructure:"worker_timeout_seconds"`
} }
type LogConfig struct { type LogConfig struct {
......
...@@ -45,6 +45,7 @@ type GatewayHandler struct { ...@@ -45,6 +45,7 @@ type GatewayHandler struct {
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
requestCaptureService *service.RequestCaptureService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper userMsgQueueHelper *UserMsgQueueHelper
maxAccountSwitches int maxAccountSwitches int
...@@ -65,6 +66,7 @@ func NewGatewayHandler( ...@@ -65,6 +66,7 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool, usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService, errorPassthroughService *service.ErrorPassthroughService,
requestCaptureService *service.RequestCaptureService,
userMsgQueueService *service.UserMessageQueueService, userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config, cfg *config.Config,
settingService *service.SettingService, settingService *service.SettingService,
...@@ -98,6 +100,7 @@ func NewGatewayHandler( ...@@ -98,6 +100,7 @@ func NewGatewayHandler(
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool, usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
requestCaptureService: requestCaptureService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper, userMsgQueueHelper: umqHelper,
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
...@@ -147,6 +150,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -147,6 +150,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
...@@ -811,6 +828,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -811,6 +828,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
} }
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......
...@@ -632,6 +632,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { ...@@ -632,6 +632,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Window5hStart: m.Window5hStart, Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart, Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart, Window7dStart: m.Window7dStart,
CaptureRequests: m.CaptureRequests,
} }
if m.Edges.User != nil { if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User) out.User = userEntityToService(m.Edges.User)
......
This diff is collapsed.
...@@ -89,6 +89,7 @@ var ProviderSet = wire.NewSet( ...@@ -89,6 +89,7 @@ var ProviderSet = wire.NewSet(
NewErrorPassthroughRepository, NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository, NewTLSFingerprintProfileRepository,
NewChannelRepository, NewChannelRepository,
NewRequestCaptureLogRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment