Commit a7386882 authored by 陈曦's avatar 陈曦
Browse files

merge capture requests branch to upstream follow

parents 110702d4 55891dff
Pipeline #82303 passed with stage
in 3 minutes and 44 seconds
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/requestcapturelog"
)
// RequestCaptureLogUpdate is the builder for updating RequestCaptureLog entities.
type RequestCaptureLogUpdate struct {
config
hooks []Hook
mutation *RequestCaptureLogMutation
}
// Where appends a list predicates to the RequestCaptureLogUpdate builder.
func (_u *RequestCaptureLogUpdate) Where(ps ...predicate.RequestCaptureLog) *RequestCaptureLogUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetAPIKeyID sets the "api_key_id" field.
func (_u *RequestCaptureLogUpdate) SetAPIKeyID(v int64) *RequestCaptureLogUpdate {
_u.mutation.ResetAPIKeyID()
_u.mutation.SetAPIKeyID(v)
return _u
}
// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableAPIKeyID(v *int64) *RequestCaptureLogUpdate {
if v != nil {
_u.SetAPIKeyID(*v)
}
return _u
}
// AddAPIKeyID adds value to the "api_key_id" field.
func (_u *RequestCaptureLogUpdate) AddAPIKeyID(v int64) *RequestCaptureLogUpdate {
_u.mutation.AddAPIKeyID(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *RequestCaptureLogUpdate) SetUserID(v int64) *RequestCaptureLogUpdate {
_u.mutation.ResetUserID()
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableUserID(v *int64) *RequestCaptureLogUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// AddUserID adds value to the "user_id" field.
func (_u *RequestCaptureLogUpdate) AddUserID(v int64) *RequestCaptureLogUpdate {
_u.mutation.AddUserID(v)
return _u
}
// SetRequestID sets the "request_id" field.
func (_u *RequestCaptureLogUpdate) SetRequestID(v string) *RequestCaptureLogUpdate {
_u.mutation.SetRequestID(v)
return _u
}
// SetNillableRequestID sets the "request_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableRequestID(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetRequestID(*v)
}
return _u
}
// ClearRequestID clears the value of the "request_id" field.
func (_u *RequestCaptureLogUpdate) ClearRequestID() *RequestCaptureLogUpdate {
_u.mutation.ClearRequestID()
return _u
}
// SetPath sets the "path" field.
func (_u *RequestCaptureLogUpdate) SetPath(v string) *RequestCaptureLogUpdate {
_u.mutation.SetPath(v)
return _u
}
// SetNillablePath sets the "path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillablePath(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetPath(*v)
}
return _u
}
// ClearPath clears the value of the "path" field.
func (_u *RequestCaptureLogUpdate) ClearPath() *RequestCaptureLogUpdate {
_u.mutation.ClearPath()
return _u
}
// SetMethod sets the "method" field.
func (_u *RequestCaptureLogUpdate) SetMethod(v string) *RequestCaptureLogUpdate {
_u.mutation.SetMethod(v)
return _u
}
// SetNillableMethod sets the "method" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableMethod(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetMethod(*v)
}
return _u
}
// ClearMethod clears the value of the "method" field.
func (_u *RequestCaptureLogUpdate) ClearMethod() *RequestCaptureLogUpdate {
_u.mutation.ClearMethod()
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *RequestCaptureLogUpdate) SetIPAddress(v string) *RequestCaptureLogUpdate {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableIPAddress(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *RequestCaptureLogUpdate) ClearIPAddress() *RequestCaptureLogUpdate {
_u.mutation.ClearIPAddress()
return _u
}
// SetRequestBody sets the "request_body" field.
func (_u *RequestCaptureLogUpdate) SetRequestBody(v string) *RequestCaptureLogUpdate {
_u.mutation.SetRequestBody(v)
return _u
}
// SetNillableRequestBody sets the "request_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableRequestBody(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetRequestBody(*v)
}
return _u
}
// ClearRequestBody clears the value of the "request_body" field.
func (_u *RequestCaptureLogUpdate) ClearRequestBody() *RequestCaptureLogUpdate {
_u.mutation.ClearRequestBody()
return _u
}
// SetResponseBody sets the "response_body" field.
func (_u *RequestCaptureLogUpdate) SetResponseBody(v string) *RequestCaptureLogUpdate {
_u.mutation.SetResponseBody(v)
return _u
}
// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableResponseBody(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetResponseBody(*v)
}
return _u
}
// ClearResponseBody clears the value of the "response_body" field.
func (_u *RequestCaptureLogUpdate) ClearResponseBody() *RequestCaptureLogUpdate {
_u.mutation.ClearResponseBody()
return _u
}
// SetNfsFilePath sets the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdate) SetNfsFilePath(v string) *RequestCaptureLogUpdate {
_u.mutation.SetNfsFilePath(v)
return _u
}
// SetNillableNfsFilePath sets the "nfs_file_path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableNfsFilePath(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetNfsFilePath(*v)
}
return _u
}
// ClearNfsFilePath clears the value of the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdate) ClearNfsFilePath() *RequestCaptureLogUpdate {
_u.mutation.ClearNfsFilePath()
return _u
}
// Mutation returns the RequestCaptureLogMutation object of the builder.
func (_u *RequestCaptureLogUpdate) Mutation() *RequestCaptureLogMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *RequestCaptureLogUpdate) Save(ctx context.Context) (int, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *RequestCaptureLogUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *RequestCaptureLogUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *RequestCaptureLogUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *RequestCaptureLogUpdate) check() error {
if v, ok := _u.mutation.RequestID(); ok {
if err := requestcapturelog.RequestIDValidator(v); err != nil {
return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.request_id": %w`, err)}
}
}
if v, ok := _u.mutation.Path(); ok {
if err := requestcapturelog.PathValidator(v); err != nil {
return &ValidationError{Name: "path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.path": %w`, err)}
}
}
if v, ok := _u.mutation.Method(); ok {
if err := requestcapturelog.MethodValidator(v); err != nil {
return &ValidationError{Name: "method", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.method": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := requestcapturelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.NfsFilePath(); ok {
if err := requestcapturelog.NfsFilePathValidator(v); err != nil {
return &ValidationError{Name: "nfs_file_path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.nfs_file_path": %w`, err)}
}
}
return nil
}
func (_u *RequestCaptureLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(requestcapturelog.Table, requestcapturelog.Columns, sqlgraph.NewFieldSpec(requestcapturelog.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.APIKeyID(); ok {
_spec.SetField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedAPIKeyID(); ok {
_spec.AddField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.UserID(); ok {
_spec.SetField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUserID(); ok {
_spec.AddField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.RequestID(); ok {
_spec.SetField(requestcapturelog.FieldRequestID, field.TypeString, value)
}
if _u.mutation.RequestIDCleared() {
_spec.ClearField(requestcapturelog.FieldRequestID, field.TypeString)
}
if value, ok := _u.mutation.Path(); ok {
_spec.SetField(requestcapturelog.FieldPath, field.TypeString, value)
}
if _u.mutation.PathCleared() {
_spec.ClearField(requestcapturelog.FieldPath, field.TypeString)
}
if value, ok := _u.mutation.Method(); ok {
_spec.SetField(requestcapturelog.FieldMethod, field.TypeString, value)
}
if _u.mutation.MethodCleared() {
_spec.ClearField(requestcapturelog.FieldMethod, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(requestcapturelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(requestcapturelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.RequestBody(); ok {
_spec.SetField(requestcapturelog.FieldRequestBody, field.TypeString, value)
}
if _u.mutation.RequestBodyCleared() {
_spec.ClearField(requestcapturelog.FieldRequestBody, field.TypeString)
}
if value, ok := _u.mutation.ResponseBody(); ok {
_spec.SetField(requestcapturelog.FieldResponseBody, field.TypeString, value)
}
if _u.mutation.ResponseBodyCleared() {
_spec.ClearField(requestcapturelog.FieldResponseBody, field.TypeString)
}
if value, ok := _u.mutation.NfsFilePath(); ok {
_spec.SetField(requestcapturelog.FieldNfsFilePath, field.TypeString, value)
}
if _u.mutation.NfsFilePathCleared() {
_spec.ClearField(requestcapturelog.FieldNfsFilePath, field.TypeString)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{requestcapturelog.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// RequestCaptureLogUpdateOne is the builder for updating a single RequestCaptureLog entity.
type RequestCaptureLogUpdateOne struct {
config
fields []string
hooks []Hook
mutation *RequestCaptureLogMutation
}
// SetAPIKeyID sets the "api_key_id" field.
func (_u *RequestCaptureLogUpdateOne) SetAPIKeyID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.ResetAPIKeyID()
_u.mutation.SetAPIKeyID(v)
return _u
}
// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableAPIKeyID(v *int64) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetAPIKeyID(*v)
}
return _u
}
// AddAPIKeyID adds value to the "api_key_id" field.
func (_u *RequestCaptureLogUpdateOne) AddAPIKeyID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.AddAPIKeyID(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *RequestCaptureLogUpdateOne) SetUserID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.ResetUserID()
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableUserID(v *int64) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// AddUserID adds value to the "user_id" field.
func (_u *RequestCaptureLogUpdateOne) AddUserID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.AddUserID(v)
return _u
}
// SetRequestID sets the "request_id" field.
func (_u *RequestCaptureLogUpdateOne) SetRequestID(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetRequestID(v)
return _u
}
// SetNillableRequestID sets the "request_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableRequestID(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetRequestID(*v)
}
return _u
}
// ClearRequestID clears the value of the "request_id" field.
func (_u *RequestCaptureLogUpdateOne) ClearRequestID() *RequestCaptureLogUpdateOne {
_u.mutation.ClearRequestID()
return _u
}
// SetPath sets the "path" field.
func (_u *RequestCaptureLogUpdateOne) SetPath(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetPath(v)
return _u
}
// SetNillablePath sets the "path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillablePath(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetPath(*v)
}
return _u
}
// ClearPath clears the value of the "path" field.
func (_u *RequestCaptureLogUpdateOne) ClearPath() *RequestCaptureLogUpdateOne {
_u.mutation.ClearPath()
return _u
}
// SetMethod sets the "method" field.
func (_u *RequestCaptureLogUpdateOne) SetMethod(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetMethod(v)
return _u
}
// SetNillableMethod sets the "method" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableMethod(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetMethod(*v)
}
return _u
}
// ClearMethod clears the value of the "method" field.
func (_u *RequestCaptureLogUpdateOne) ClearMethod() *RequestCaptureLogUpdateOne {
_u.mutation.ClearMethod()
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *RequestCaptureLogUpdateOne) SetIPAddress(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableIPAddress(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *RequestCaptureLogUpdateOne) ClearIPAddress() *RequestCaptureLogUpdateOne {
_u.mutation.ClearIPAddress()
return _u
}
// SetRequestBody sets the "request_body" field.
func (_u *RequestCaptureLogUpdateOne) SetRequestBody(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetRequestBody(v)
return _u
}
// SetNillableRequestBody sets the "request_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableRequestBody(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetRequestBody(*v)
}
return _u
}
// ClearRequestBody clears the value of the "request_body" field.
func (_u *RequestCaptureLogUpdateOne) ClearRequestBody() *RequestCaptureLogUpdateOne {
_u.mutation.ClearRequestBody()
return _u
}
// SetResponseBody sets the "response_body" field.
func (_u *RequestCaptureLogUpdateOne) SetResponseBody(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetResponseBody(v)
return _u
}
// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableResponseBody(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetResponseBody(*v)
}
return _u
}
// ClearResponseBody clears the value of the "response_body" field.
func (_u *RequestCaptureLogUpdateOne) ClearResponseBody() *RequestCaptureLogUpdateOne {
_u.mutation.ClearResponseBody()
return _u
}
// SetNfsFilePath sets the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdateOne) SetNfsFilePath(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetNfsFilePath(v)
return _u
}
// SetNillableNfsFilePath sets the "nfs_file_path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableNfsFilePath(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetNfsFilePath(*v)
}
return _u
}
// ClearNfsFilePath clears the value of the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdateOne) ClearNfsFilePath() *RequestCaptureLogUpdateOne {
_u.mutation.ClearNfsFilePath()
return _u
}
// Mutation returns the RequestCaptureLogMutation object of the builder.
func (_u *RequestCaptureLogUpdateOne) Mutation() *RequestCaptureLogMutation {
return _u.mutation
}
// Where appends a list predicates to the RequestCaptureLogUpdate builder.
func (_u *RequestCaptureLogUpdateOne) Where(ps ...predicate.RequestCaptureLog) *RequestCaptureLogUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *RequestCaptureLogUpdateOne) Select(field string, fields ...string) *RequestCaptureLogUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated RequestCaptureLog entity.
func (_u *RequestCaptureLogUpdateOne) Save(ctx context.Context) (*RequestCaptureLog, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *RequestCaptureLogUpdateOne) SaveX(ctx context.Context) *RequestCaptureLog {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *RequestCaptureLogUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *RequestCaptureLogUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *RequestCaptureLogUpdateOne) check() error {
if v, ok := _u.mutation.RequestID(); ok {
if err := requestcapturelog.RequestIDValidator(v); err != nil {
return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.request_id": %w`, err)}
}
}
if v, ok := _u.mutation.Path(); ok {
if err := requestcapturelog.PathValidator(v); err != nil {
return &ValidationError{Name: "path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.path": %w`, err)}
}
}
if v, ok := _u.mutation.Method(); ok {
if err := requestcapturelog.MethodValidator(v); err != nil {
return &ValidationError{Name: "method", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.method": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := requestcapturelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.NfsFilePath(); ok {
if err := requestcapturelog.NfsFilePathValidator(v); err != nil {
return &ValidationError{Name: "nfs_file_path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.nfs_file_path": %w`, err)}
}
}
return nil
}
func (_u *RequestCaptureLogUpdateOne) sqlSave(ctx context.Context) (_node *RequestCaptureLog, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(requestcapturelog.Table, requestcapturelog.Columns, sqlgraph.NewFieldSpec(requestcapturelog.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "RequestCaptureLog.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, requestcapturelog.FieldID)
for _, f := range fields {
if !requestcapturelog.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != requestcapturelog.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.APIKeyID(); ok {
_spec.SetField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedAPIKeyID(); ok {
_spec.AddField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.UserID(); ok {
_spec.SetField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUserID(); ok {
_spec.AddField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.RequestID(); ok {
_spec.SetField(requestcapturelog.FieldRequestID, field.TypeString, value)
}
if _u.mutation.RequestIDCleared() {
_spec.ClearField(requestcapturelog.FieldRequestID, field.TypeString)
}
if value, ok := _u.mutation.Path(); ok {
_spec.SetField(requestcapturelog.FieldPath, field.TypeString, value)
}
if _u.mutation.PathCleared() {
_spec.ClearField(requestcapturelog.FieldPath, field.TypeString)
}
if value, ok := _u.mutation.Method(); ok {
_spec.SetField(requestcapturelog.FieldMethod, field.TypeString, value)
}
if _u.mutation.MethodCleared() {
_spec.ClearField(requestcapturelog.FieldMethod, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(requestcapturelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(requestcapturelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.RequestBody(); ok {
_spec.SetField(requestcapturelog.FieldRequestBody, field.TypeString, value)
}
if _u.mutation.RequestBodyCleared() {
_spec.ClearField(requestcapturelog.FieldRequestBody, field.TypeString)
}
if value, ok := _u.mutation.ResponseBody(); ok {
_spec.SetField(requestcapturelog.FieldResponseBody, field.TypeString, value)
}
if _u.mutation.ResponseBodyCleared() {
_spec.ClearField(requestcapturelog.FieldResponseBody, field.TypeString)
}
if value, ok := _u.mutation.NfsFilePath(); ok {
_spec.SetField(requestcapturelog.FieldNfsFilePath, field.TypeString, value)
}
if _u.mutation.NfsFilePathCleared() {
_spec.ClearField(requestcapturelog.FieldNfsFilePath, field.TypeString)
}
_node = &RequestCaptureLog{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{requestcapturelog.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}
......@@ -28,6 +28,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/requestcapturelog"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
......@@ -140,6 +141,10 @@ func init() {
apikeyDescUsage7d := apikeyFields[16].Descriptor()
// apikey.DefaultUsage7d holds the default value on creation for the usage_7d field.
apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64)
// apikeyDescCaptureRequests is the schema descriptor for capture_requests field.
apikeyDescCaptureRequests := apikeyFields[20].Descriptor()
// apikey.DefaultCaptureRequests holds the default value on creation for the capture_requests field.
apikey.DefaultCaptureRequests = apikeyDescCaptureRequests.Default.(bool)
accountMixin := schema.Account{}.Mixin()
accountMixinHooks1 := accountMixin[1].Hooks()
account.Hooks[0] = accountMixinHooks1[0]
......@@ -1377,6 +1382,32 @@ func init() {
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
requestcapturelogFields := schema.RequestCaptureLog{}.Fields()
_ = requestcapturelogFields
// requestcapturelogDescRequestID is the schema descriptor for request_id field.
requestcapturelogDescRequestID := requestcapturelogFields[2].Descriptor()
// requestcapturelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save.
requestcapturelog.RequestIDValidator = requestcapturelogDescRequestID.Validators[0].(func(string) error)
// requestcapturelogDescPath is the schema descriptor for path field.
requestcapturelogDescPath := requestcapturelogFields[3].Descriptor()
// requestcapturelog.PathValidator is a validator for the "path" field. It is called by the builders before save.
requestcapturelog.PathValidator = requestcapturelogDescPath.Validators[0].(func(string) error)
// requestcapturelogDescMethod is the schema descriptor for method field.
requestcapturelogDescMethod := requestcapturelogFields[4].Descriptor()
// requestcapturelog.MethodValidator is a validator for the "method" field. It is called by the builders before save.
requestcapturelog.MethodValidator = requestcapturelogDescMethod.Validators[0].(func(string) error)
// requestcapturelogDescIPAddress is the schema descriptor for ip_address field.
requestcapturelogDescIPAddress := requestcapturelogFields[5].Descriptor()
// requestcapturelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
requestcapturelog.IPAddressValidator = requestcapturelogDescIPAddress.Validators[0].(func(string) error)
// requestcapturelogDescNfsFilePath is the schema descriptor for nfs_file_path field.
requestcapturelogDescNfsFilePath := requestcapturelogFields[8].Descriptor()
// requestcapturelog.NfsFilePathValidator is a validator for the "nfs_file_path" field. It is called by the builders before save.
requestcapturelog.NfsFilePathValidator = requestcapturelogDescNfsFilePath.Validators[0].(func(string) error)
// requestcapturelogDescCreatedAt is the schema descriptor for created_at field.
requestcapturelogDescCreatedAt := requestcapturelogFields[9].Descriptor()
// requestcapturelog.DefaultCreatedAt holds the default value on creation for the created_at field.
requestcapturelog.DefaultCreatedAt = requestcapturelogDescCreatedAt.Default.(func() time.Time)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
securitysecretMixinFields0 := securitysecretMixin[0].Fields()
_ = securitysecretMixinFields0
......
......@@ -115,6 +115,11 @@ func (APIKey) Fields() []ent.Field {
Optional().
Nillable().
Comment("Start time of the current 7d rate limit window"),
// ========== Request capture ==========
field.Bool("capture_requests").
Default(false).
Comment("是否对该 API Key 的请求体进行存储捕获"),
}
}
......
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// RequestCaptureLog 记录指定 API Key 的请求体,用于审计和分析。
// 只追加,不支持更新/删除(同 PaymentAuditLog 模式)。
type RequestCaptureLog struct {
ent.Schema
}
func (RequestCaptureLog) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "request_capture_logs"},
}
}
func (RequestCaptureLog) Fields() []ent.Field {
return []ent.Field{
field.Int64("api_key_id"),
field.Int64("user_id"),
field.String("request_id").
MaxLen(64).
Optional().
Nillable(),
field.String("path").
MaxLen(100).
Optional().
Nillable(),
field.String("method").
MaxLen(10).
Optional().
Nillable(),
field.String("ip_address").
MaxLen(45).
Optional().
Nillable(),
// request_body 存原始 JSON 文本,不加索引,避免影响查询计划
field.Text("request_body").
Optional().
Nillable(),
// response_body 存响应文本(非 streaming 为完整 JSON,streaming 为拼接的 assistant text)
field.Text("response_body").
Optional().
Nillable(),
// nfs_file_path NFS 文件路径快照,方便核查
field.String("nfs_file_path").
MaxLen(500).
Optional().
Nillable(),
field.Time("created_at").
Default(time.Now).
Immutable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (RequestCaptureLog) Edges() []ent.Edge {
return nil
}
func (RequestCaptureLog) Indexes() []ent.Index {
return []ent.Index{
index.Fields("api_key_id", "created_at"),
index.Fields("user_id"),
}
}
......@@ -60,6 +60,8 @@ type Tx struct {
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient
// RequestCaptureLog is the client for interacting with the RequestCaptureLog builders.
RequestCaptureLog *RequestCaptureLogClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
......@@ -236,6 +238,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.RequestCaptureLog = NewRequestCaptureLogClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config)
......
......@@ -89,6 +89,16 @@ type Config struct {
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
RequestCapture RequestCaptureConfig `mapstructure:"request_capture"`
}
// RequestCaptureConfig 配置请求体捕获功能
type RequestCaptureConfig struct {
// NFSPath 为本地挂载的 NFS 根目录(例如 /mnt/nfs/requests)。
// 留空则跳过文件写入,只写数据库。
NFSPath string `mapstructure:"nfs_path"`
// WorkerTimeoutSeconds 单次异步写入的超时时间(秒),默认 5。
WorkerTimeoutSeconds int `mapstructure:"worker_timeout_seconds"`
}
type LogConfig struct {
......
......@@ -565,6 +565,17 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*service.APIKey, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
s.apiKeys[i].CaptureRequests = enabled
k := s.apiKeys[i]
return &k, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
return nil
}
......
......@@ -10,6 +10,11 @@ import (
"github.com/gin-gonic/gin"
)
// AdminSetCaptureRequestsRequest 请求体:设置/清除 capture_requests
type AdminSetCaptureRequestsRequest struct {
Enabled bool `json:"enabled"`
}
// AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct {
adminService service.AdminService
......@@ -27,6 +32,33 @@ type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
}
// SetCaptureRequests 开启或关闭指定 API Key 的请求体捕获,并立即失效认证缓存。
// PUT /api/v1/admin/api-keys/:id/capture-requests
func (h *AdminAPIKeyHandler) SetCaptureRequests(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminSetCaptureRequestsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
apiKey, err := h.adminService.AdminSetCaptureRequests(c.Request.Context(), keyID, req.Enabled)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"id": apiKey.ID,
"capture_requests": apiKey.CaptureRequests,
})
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
......
......@@ -45,6 +45,7 @@ type GatewayHandler struct {
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
requestCaptureService *service.RequestCaptureService
concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper
maxAccountSwitches int
......@@ -65,6 +66,7 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
requestCaptureService *service.RequestCaptureService,
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
settingService *service.SettingService,
......@@ -98,6 +100,7 @@ func NewGatewayHandler(
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
requestCaptureService: requestCaptureService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper,
maxAccountSwitches: maxAccountSwitches,
......@@ -147,6 +150,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
......@@ -158,6 +175,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 若 request_capture 已启用且为流式请求,注入响应体采集 buffer 到 context
// service 层的 handleStreamingResponse 会将 text_delta 内容写入此 buffer
if captureID > 0 && reqStream {
captureRespBuilder := &strings.Builder{}
c.Request = c.Request.WithContext(
context.WithValue(c.Request.Context(), ctxkey.ResponseCaptureBuffer, captureRespBuilder),
)
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
......@@ -483,6 +509,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......@@ -840,6 +871,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......
......@@ -8,6 +8,7 @@ import (
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -60,6 +61,20 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
return
}
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
......@@ -253,6 +268,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
......
......@@ -8,6 +8,7 @@ import (
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -62,6 +63,20 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
......@@ -265,6 +280,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
......
......@@ -13,6 +13,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -32,6 +33,7 @@ type OpenAIGatewayHandler struct {
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
requestCaptureService *service.RequestCaptureService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
cfg *config.Config
......@@ -62,6 +64,7 @@ func NewOpenAIGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
requestCaptureService *service.RequestCaptureService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
......@@ -78,6 +81,7 @@ func NewOpenAIGatewayHandler(
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
requestCaptureService: requestCaptureService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
......@@ -135,6 +139,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body)
sessionHashBody := body
if service.IsOpenAIResponsesCompactPathForTest(c) {
......@@ -389,6 +407,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
......@@ -590,6 +613,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
......@@ -764,6 +801,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
......
......@@ -55,4 +55,8 @@ const (
// ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
ClaudeCodeVersion Key = "ctx_claude_code_version"
// ResponseCaptureBuffer 用于在 streaming 响应中收集 assistant 文本,供 request_capture 功能使用。
// 值类型为 *strings.Builder,由 handler 层注入,service 层只负责追加文本。
ResponseCaptureBuffer Key = "ctx_response_capture_buffer"
)
......@@ -134,6 +134,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldRateLimit5h,
apikey.FieldRateLimit1d,
apikey.FieldRateLimit7d,
apikey.FieldCaptureRequests,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
......@@ -255,6 +256,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearIPBlacklist()
}
builder.SetCaptureRequests(key.CaptureRequests)
affected, err := builder.Save(ctx)
if err != nil {
return err
......@@ -634,9 +637,10 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Usage5h: m.Usage5h,
Usage1d: m.Usage1d,
Usage7d: m.Usage7d,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
CaptureRequests: m.CaptureRequests,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
......
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type requestCaptureLogRepository struct {
client *dbent.Client
}
// NewRequestCaptureLogRepository 创建请求捕获日志仓储实例。
func NewRequestCaptureLogRepository(client *dbent.Client) service.RequestCaptureLogRepository {
return &requestCaptureLogRepository{client: client}
}
func (r *requestCaptureLogRepository) Create(ctx context.Context, params service.CreateRequestCaptureLogParams) (int64, error) {
q := r.client.RequestCaptureLog.Create().
SetAPIKeyID(params.APIKeyID).
SetUserID(params.UserID)
if params.RequestID != "" {
q = q.SetRequestID(params.RequestID)
}
if params.Path != "" {
q = q.SetPath(params.Path)
}
if params.Method != "" {
q = q.SetMethod(params.Method)
}
if params.IPAddress != "" {
q = q.SetIPAddress(params.IPAddress)
}
if params.RequestBody != "" {
q = q.SetRequestBody(params.RequestBody)
}
if params.NFSFilePath != "" {
q = q.SetNfsFilePath(params.NFSFilePath)
}
row, err := q.Save(ctx)
if err != nil {
return 0, err
}
return row.ID, nil
}
func (r *requestCaptureLogRepository) UpdateResponseBody(ctx context.Context, id int64, responseBody string) error {
return r.client.RequestCaptureLog.UpdateOneID(id).
SetResponseBody(responseBody).
Exec(ctx)
}
......@@ -92,6 +92,7 @@ var ProviderSet = wire.NewSet(
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
NewAffiliateRepository,
NewRequestCaptureLogRepository,
// Cache implementations
NewGatewayCache,
......
......@@ -101,6 +101,7 @@ func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
apiKeys := admin.Group("/api-keys")
{
apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
apiKeys.PUT("/:id/capture-requests", h.Admin.APIKey.SetCaptureRequests)
}
}
......
......@@ -58,6 +58,7 @@ type AdminService interface {
// API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*APIKey, error)
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
......@@ -1961,6 +1962,22 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return result, nil
}
// AdminSetCaptureRequests 设置或清除指定 API Key 的请求捕获开关,并立即失效认证缓存。
func (s *adminServiceImpl) AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
if err != nil {
return nil, err
}
apiKey.CaptureRequests = enabled
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
return apiKey, nil
}
// ReplaceUserGroup 替换用户的专属分组
func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
if oldGroupID == newGroupID {
......
......@@ -21,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
......@@ -1739,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var responseBody string
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
......@@ -1749,6 +1751,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
// 流式:从 context buffer 读取采集的文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else {
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
......@@ -1758,6 +1764,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
}
return &ForwardResult{
......@@ -1769,6 +1776,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ResponseBody: responseBody,
}, nil
}
......@@ -2421,6 +2429,7 @@ handleSuccess:
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var responseBody string
if stream {
// 客户端要求流式,直接透传
......@@ -2432,6 +2441,10 @@ handleSuccess:
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
// 流式:从 context buffer 读取采集的文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else {
// 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
......@@ -2441,6 +2454,7 @@ handleSuccess:
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
}
if usage == nil {
......@@ -2465,6 +2479,7 @@ handleSuccess:
ClientDisconnect: clientDisconnect,
ImageCount: imageCount,
ImageSize: imageSize,
ResponseBody: responseBody,
}, nil
}
......@@ -2968,7 +2983,8 @@ func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur
type antigravityStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
clientDisconnect bool // 客户端是否在流式传输过程中断开
responseBody string // 响应体内容(非流式:完整 JSON;流式:从上下文 buffer 读取)
}
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
......@@ -3124,6 +3140,9 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本内容
captureBuilder, _ := c.Request.Context().Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
......@@ -3197,6 +3216,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs = &ms
}
// 采集文本用于响应捕获
if captureBuilder != nil && len(inner) > 0 {
gjson.GetBytes(inner, "candidates.0.content.parts").ForEach(func(_, v gjson.Result) bool {
if t := v.Get("text").String(); t != "" {
captureBuilder.WriteString(t)
}
return true
})
}
cw.Fprintf("data: %s\n\n", payload)
continue
}
......@@ -3418,7 +3447,7 @@ returnResponse:
}
c.Data(http.StatusOK, "application/json", respBody)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, responseBody: strings.Join(collectedTextParts, "")}, nil
}
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
......@@ -3867,7 +3896,7 @@ returnResponse:
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, responseBody: string(claudeResp)}, nil
}
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
......@@ -3971,6 +4000,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本 delta
captureBuilder, _ := c.Request.Context().Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
......@@ -4024,7 +4056,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
lastDataAt = time.Now()
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
trimmedLine := strings.TrimRight(ev.line, "\r\n")
claudeEvents := processor.ProcessLine(trimmedLine)
if len(claudeEvents) > 0 {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
......@@ -4033,6 +4066,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw.Write(claudeEvents)
}
// 采集文本用于响应捕获
if captureBuilder != nil && strings.HasPrefix(trimmedLine, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(trimmedLine, "data:"))
if data != "" && data != "[DONE]" {
// v1internal 格式: response.candidates.0.content.parts.0.text
// 直接 Gemini 格式: candidates.0.content.parts.0.text
text := gjson.Get(data, "response.candidates.0.content.parts.0.text").String()
if text == "" {
text = gjson.Get(data, "candidates.0.content.parts.0.text").String()
}
if text != "" {
captureBuilder.WriteString(text)
}
}
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
......@@ -4286,6 +4335,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var responseBody string
if claudeReq.Stream {
// 流式响应:透传
......@@ -4295,10 +4345,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
streamRes := s.streamUpstreamResponse(c, resp, startTime)
streamRes := s.streamUpstreamResponse(ctx, c, resp, startTime)
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
// 从 context buffer 读取已收集的 assistant 文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
......@@ -4308,6 +4362,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
// 提取 usage
usage = s.extractClaudeUsage(respBody)
responseBody = string(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
......@@ -4324,6 +4379,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
Duration: duration,
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ResponseBody: responseBody,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
......@@ -4334,10 +4390,13 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
}
// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
func (s *AntigravityGatewayService) streamUpstreamResponse(ctx context.Context, c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
usage := &ClaudeUsage{}
var firstTokenMs *int
// 响应体捕获:若 context 中注入了 ResponseCaptureBuffer,则收集 text_delta 文本
captureBuilder, _ := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
......@@ -4435,6 +4494,16 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
// 尝试从 message_delta 或 message_stop 事件提取 usage
s.extractSSEUsage(line, usage)
// 收集 assistant text(仅 content_block_delta + text_delta)
if captureBuilder != nil && strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if gjson.Get(data, "type").String() == "content_block_delta" {
if gjson.Get(data, "delta.type").String() == "text_delta" {
captureBuilder.WriteString(gjson.Get(data, "delta.text").String())
}
}
}
// 透传行
cw.Fprintf("%s\n", line)
......
......@@ -60,6 +60,9 @@ type APIKey struct {
Window5hStart *time.Time // Start of current 5h window
Window1dStart *time.Time // Start of current 1d window
Window7dStart *time.Time // Start of current 7d window
// 请求体捕获
CaptureRequests bool // 是否对该 Key 的请求体进行存储捕获
}
func (k *APIKey) IsActive() bool {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment