Commit 994da655 authored by 陈曦's avatar 陈曦
Browse files

收集req和resp的相关更改

parent 8f7ac1ea
......@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/requestcapturelog"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
......@@ -132,6 +133,10 @@ func init() {
apikeyDescUsage7d := apikeyFields[16].Descriptor()
// apikey.DefaultUsage7d holds the default value on creation for the usage_7d field.
apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64)
// apikeyDescCaptureRequests is the schema descriptor for capture_requests field.
apikeyDescCaptureRequests := apikeyFields[20].Descriptor()
// apikey.DefaultCaptureRequests holds the default value on creation for the capture_requests field.
apikey.DefaultCaptureRequests = apikeyDescCaptureRequests.Default.(bool)
accountMixin := schema.Account{}.Mixin()
accountMixinHooks1 := accountMixin[1].Hooks()
account.Hooks[0] = accountMixinHooks1[0]
......@@ -867,6 +872,32 @@ func init() {
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
requestcapturelogFields := schema.RequestCaptureLog{}.Fields()
_ = requestcapturelogFields
// requestcapturelogDescRequestID is the schema descriptor for request_id field.
requestcapturelogDescRequestID := requestcapturelogFields[2].Descriptor()
// requestcapturelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save.
requestcapturelog.RequestIDValidator = requestcapturelogDescRequestID.Validators[0].(func(string) error)
// requestcapturelogDescPath is the schema descriptor for path field.
requestcapturelogDescPath := requestcapturelogFields[3].Descriptor()
// requestcapturelog.PathValidator is a validator for the "path" field. It is called by the builders before save.
requestcapturelog.PathValidator = requestcapturelogDescPath.Validators[0].(func(string) error)
// requestcapturelogDescMethod is the schema descriptor for method field.
requestcapturelogDescMethod := requestcapturelogFields[4].Descriptor()
// requestcapturelog.MethodValidator is a validator for the "method" field. It is called by the builders before save.
requestcapturelog.MethodValidator = requestcapturelogDescMethod.Validators[0].(func(string) error)
// requestcapturelogDescIPAddress is the schema descriptor for ip_address field.
requestcapturelogDescIPAddress := requestcapturelogFields[5].Descriptor()
// requestcapturelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
requestcapturelog.IPAddressValidator = requestcapturelogDescIPAddress.Validators[0].(func(string) error)
// requestcapturelogDescNfsFilePath is the schema descriptor for nfs_file_path field.
requestcapturelogDescNfsFilePath := requestcapturelogFields[8].Descriptor()
// requestcapturelog.NfsFilePathValidator is a validator for the "nfs_file_path" field. It is called by the builders before save.
requestcapturelog.NfsFilePathValidator = requestcapturelogDescNfsFilePath.Validators[0].(func(string) error)
// requestcapturelogDescCreatedAt is the schema descriptor for created_at field.
requestcapturelogDescCreatedAt := requestcapturelogFields[9].Descriptor()
// requestcapturelog.DefaultCreatedAt holds the default value on creation for the created_at field.
requestcapturelog.DefaultCreatedAt = requestcapturelogDescCreatedAt.Default.(func() time.Time)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
securitysecretMixinFields0 := securitysecretMixin[0].Fields()
_ = securitysecretMixinFields0
......
......@@ -115,6 +115,11 @@ func (APIKey) Fields() []ent.Field {
Optional().
Nillable().
Comment("Start time of the current 7d rate limit window"),
// ========== Request capture ==========
field.Bool("capture_requests").
Default(false).
Comment("是否对该 API Key 的请求体进行存储捕获"),
}
}
......
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// RequestCaptureLog 记录指定 API Key 的请求体,用于审计和分析。
// 只追加,不支持更新/删除(同 PaymentAuditLog 模式)。
type RequestCaptureLog struct {
ent.Schema
}
func (RequestCaptureLog) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "request_capture_logs"},
}
}
func (RequestCaptureLog) Fields() []ent.Field {
return []ent.Field{
field.Int64("api_key_id"),
field.Int64("user_id"),
field.String("request_id").
MaxLen(64).
Optional().
Nillable(),
field.String("path").
MaxLen(100).
Optional().
Nillable(),
field.String("method").
MaxLen(10).
Optional().
Nillable(),
field.String("ip_address").
MaxLen(45).
Optional().
Nillable(),
// request_body 存原始 JSON 文本,不加索引,避免影响查询计划
field.Text("request_body").
Optional().
Nillable(),
// response_body 存响应文本(非 streaming 为完整 JSON,streaming 为拼接的 assistant text)
field.Text("response_body").
Optional().
Nillable(),
// nfs_file_path NFS 文件路径快照,方便核查
field.String("nfs_file_path").
MaxLen(500).
Optional().
Nillable(),
field.Time("created_at").
Default(time.Now).
Immutable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (RequestCaptureLog) Edges() []ent.Edge {
return nil
}
func (RequestCaptureLog) Indexes() []ent.Index {
return []ent.Index{
index.Fields("api_key_id", "created_at"),
index.Fields("user_id"),
}
}
......@@ -44,6 +44,8 @@ type Tx struct {
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient
// RequestCaptureLog is the client for interacting with the RequestCaptureLog builders.
RequestCaptureLog *RequestCaptureLogClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
......@@ -212,6 +214,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.RequestCaptureLog = NewRequestCaptureLogClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config)
......
......@@ -83,6 +83,16 @@ type Config struct {
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
RequestCapture RequestCaptureConfig `mapstructure:"request_capture"`
}
// RequestCaptureConfig 配置请求体捕获功能
type RequestCaptureConfig struct {
// NFSPath 为本地挂载的 NFS 根目录(例如 /mnt/nfs/requests)。
// 留空则跳过文件写入,只写数据库。
NFSPath string `mapstructure:"nfs_path"`
// WorkerTimeoutSeconds 单次异步写入的超时时间(秒),默认 5。
WorkerTimeoutSeconds int `mapstructure:"worker_timeout_seconds"`
}
type LogConfig struct {
......
......@@ -45,6 +45,7 @@ type GatewayHandler struct {
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
requestCaptureService *service.RequestCaptureService
concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper
maxAccountSwitches int
......@@ -65,6 +66,7 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
requestCaptureService *service.RequestCaptureService,
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
settingService *service.SettingService,
......@@ -98,6 +100,7 @@ func NewGatewayHandler(
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
requestCaptureService: requestCaptureService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper,
maxAccountSwitches: maxAccountSwitches,
......@@ -147,6 +150,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
......@@ -811,6 +828,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......
......@@ -629,9 +629,10 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Usage5h: m.Usage5h,
Usage1d: m.Usage1d,
Usage7d: m.Usage7d,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
CaptureRequests: m.CaptureRequests,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment