Commit 994da655 authored by 陈曦's avatar 陈曦
Browse files

收集req和resp的相关更改

parent 8f7ac1ea
......@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/requestcapturelog"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
......@@ -132,6 +133,10 @@ func init() {
apikeyDescUsage7d := apikeyFields[16].Descriptor()
// apikey.DefaultUsage7d holds the default value on creation for the usage_7d field.
apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64)
// apikeyDescCaptureRequests is the schema descriptor for capture_requests field.
apikeyDescCaptureRequests := apikeyFields[20].Descriptor()
// apikey.DefaultCaptureRequests holds the default value on creation for the capture_requests field.
apikey.DefaultCaptureRequests = apikeyDescCaptureRequests.Default.(bool)
accountMixin := schema.Account{}.Mixin()
accountMixinHooks1 := accountMixin[1].Hooks()
account.Hooks[0] = accountMixinHooks1[0]
......@@ -867,6 +872,32 @@ func init() {
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
requestcapturelogFields := schema.RequestCaptureLog{}.Fields()
_ = requestcapturelogFields
// requestcapturelogDescRequestID is the schema descriptor for request_id field.
requestcapturelogDescRequestID := requestcapturelogFields[2].Descriptor()
// requestcapturelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save.
requestcapturelog.RequestIDValidator = requestcapturelogDescRequestID.Validators[0].(func(string) error)
// requestcapturelogDescPath is the schema descriptor for path field.
requestcapturelogDescPath := requestcapturelogFields[3].Descriptor()
// requestcapturelog.PathValidator is a validator for the "path" field. It is called by the builders before save.
requestcapturelog.PathValidator = requestcapturelogDescPath.Validators[0].(func(string) error)
// requestcapturelogDescMethod is the schema descriptor for method field.
requestcapturelogDescMethod := requestcapturelogFields[4].Descriptor()
// requestcapturelog.MethodValidator is a validator for the "method" field. It is called by the builders before save.
requestcapturelog.MethodValidator = requestcapturelogDescMethod.Validators[0].(func(string) error)
// requestcapturelogDescIPAddress is the schema descriptor for ip_address field.
requestcapturelogDescIPAddress := requestcapturelogFields[5].Descriptor()
// requestcapturelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
requestcapturelog.IPAddressValidator = requestcapturelogDescIPAddress.Validators[0].(func(string) error)
// requestcapturelogDescNfsFilePath is the schema descriptor for nfs_file_path field.
requestcapturelogDescNfsFilePath := requestcapturelogFields[8].Descriptor()
// requestcapturelog.NfsFilePathValidator is a validator for the "nfs_file_path" field. It is called by the builders before save.
requestcapturelog.NfsFilePathValidator = requestcapturelogDescNfsFilePath.Validators[0].(func(string) error)
// requestcapturelogDescCreatedAt is the schema descriptor for created_at field.
requestcapturelogDescCreatedAt := requestcapturelogFields[9].Descriptor()
// requestcapturelog.DefaultCreatedAt holds the default value on creation for the created_at field.
requestcapturelog.DefaultCreatedAt = requestcapturelogDescCreatedAt.Default.(func() time.Time)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
securitysecretMixinFields0 := securitysecretMixin[0].Fields()
_ = securitysecretMixinFields0
......
......@@ -115,6 +115,11 @@ func (APIKey) Fields() []ent.Field {
Optional().
Nillable().
Comment("Start time of the current 7d rate limit window"),
// ========== Request capture ==========
field.Bool("capture_requests").
Default(false).
Comment("是否对该 API Key 的请求体进行存储捕获"),
}
}
......
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// RequestCaptureLog 记录指定 API Key 的请求体,用于审计和分析。
// 只追加,不支持更新/删除(同 PaymentAuditLog 模式)。
type RequestCaptureLog struct {
ent.Schema
}
func (RequestCaptureLog) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "request_capture_logs"},
}
}
func (RequestCaptureLog) Fields() []ent.Field {
return []ent.Field{
field.Int64("api_key_id"),
field.Int64("user_id"),
field.String("request_id").
MaxLen(64).
Optional().
Nillable(),
field.String("path").
MaxLen(100).
Optional().
Nillable(),
field.String("method").
MaxLen(10).
Optional().
Nillable(),
field.String("ip_address").
MaxLen(45).
Optional().
Nillable(),
// request_body 存原始 JSON 文本,不加索引,避免影响查询计划
field.Text("request_body").
Optional().
Nillable(),
// response_body 存响应文本(非 streaming 为完整 JSON,streaming 为拼接的 assistant text)
field.Text("response_body").
Optional().
Nillable(),
// nfs_file_path NFS 文件路径快照,方便核查
field.String("nfs_file_path").
MaxLen(500).
Optional().
Nillable(),
field.Time("created_at").
Default(time.Now).
Immutable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (RequestCaptureLog) Edges() []ent.Edge {
return nil
}
func (RequestCaptureLog) Indexes() []ent.Index {
return []ent.Index{
index.Fields("api_key_id", "created_at"),
index.Fields("user_id"),
}
}
......@@ -44,6 +44,8 @@ type Tx struct {
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient
// RequestCaptureLog is the client for interacting with the RequestCaptureLog builders.
RequestCaptureLog *RequestCaptureLogClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
......@@ -212,6 +214,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.RequestCaptureLog = NewRequestCaptureLogClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config)
......
......@@ -83,6 +83,16 @@ type Config struct {
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
RequestCapture RequestCaptureConfig `mapstructure:"request_capture"`
}
// RequestCaptureConfig 配置请求体捕获功能
type RequestCaptureConfig struct {
// NFSPath 为本地挂载的 NFS 根目录(例如 /mnt/nfs/requests)。
// 留空则跳过文件写入,只写数据库。
NFSPath string `mapstructure:"nfs_path"`
// WorkerTimeoutSeconds 单次异步写入的超时时间(秒),默认 5。
WorkerTimeoutSeconds int `mapstructure:"worker_timeout_seconds"`
}
type LogConfig struct {
......
......@@ -45,6 +45,7 @@ type GatewayHandler struct {
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
requestCaptureService *service.RequestCaptureService
concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper
maxAccountSwitches int
......@@ -65,6 +66,7 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
requestCaptureService *service.RequestCaptureService,
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
settingService *service.SettingService,
......@@ -98,6 +100,7 @@ func NewGatewayHandler(
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
requestCaptureService: requestCaptureService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper,
maxAccountSwitches: maxAccountSwitches,
......@@ -147,6 +150,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
......@@ -811,6 +828,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......
......@@ -632,6 +632,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
CaptureRequests: m.CaptureRequests,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
......
This diff is collapsed.
......@@ -89,6 +89,7 @@ var ProviderSet = wire.NewSet(
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
NewRequestCaptureLogRepository,
// Cache implementations
NewGatewayCache,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment