"backend/internal/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "9d801595c95eb5f5616bca0ec409a42d73325987"
Commit a7386882 authored by 陈曦's avatar 陈曦
Browse files

merge capture requests branch to upstream follow

parents 110702d4 55891dff
Pipeline #82303 passed with stage
in 3 minutes and 44 seconds
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/requestcapturelog"
)
// RequestCaptureLogUpdate is the builder for updating RequestCaptureLog entities.
type RequestCaptureLogUpdate struct {
config
hooks []Hook
mutation *RequestCaptureLogMutation
}
// Where appends a list predicates to the RequestCaptureLogUpdate builder.
func (_u *RequestCaptureLogUpdate) Where(ps ...predicate.RequestCaptureLog) *RequestCaptureLogUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetAPIKeyID sets the "api_key_id" field.
func (_u *RequestCaptureLogUpdate) SetAPIKeyID(v int64) *RequestCaptureLogUpdate {
_u.mutation.ResetAPIKeyID()
_u.mutation.SetAPIKeyID(v)
return _u
}
// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableAPIKeyID(v *int64) *RequestCaptureLogUpdate {
if v != nil {
_u.SetAPIKeyID(*v)
}
return _u
}
// AddAPIKeyID adds value to the "api_key_id" field.
func (_u *RequestCaptureLogUpdate) AddAPIKeyID(v int64) *RequestCaptureLogUpdate {
_u.mutation.AddAPIKeyID(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *RequestCaptureLogUpdate) SetUserID(v int64) *RequestCaptureLogUpdate {
_u.mutation.ResetUserID()
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableUserID(v *int64) *RequestCaptureLogUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// AddUserID adds value to the "user_id" field.
func (_u *RequestCaptureLogUpdate) AddUserID(v int64) *RequestCaptureLogUpdate {
_u.mutation.AddUserID(v)
return _u
}
// SetRequestID sets the "request_id" field.
func (_u *RequestCaptureLogUpdate) SetRequestID(v string) *RequestCaptureLogUpdate {
_u.mutation.SetRequestID(v)
return _u
}
// SetNillableRequestID sets the "request_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableRequestID(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetRequestID(*v)
}
return _u
}
// ClearRequestID clears the value of the "request_id" field.
func (_u *RequestCaptureLogUpdate) ClearRequestID() *RequestCaptureLogUpdate {
_u.mutation.ClearRequestID()
return _u
}
// SetPath sets the "path" field.
func (_u *RequestCaptureLogUpdate) SetPath(v string) *RequestCaptureLogUpdate {
_u.mutation.SetPath(v)
return _u
}
// SetNillablePath sets the "path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillablePath(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetPath(*v)
}
return _u
}
// ClearPath clears the value of the "path" field.
func (_u *RequestCaptureLogUpdate) ClearPath() *RequestCaptureLogUpdate {
_u.mutation.ClearPath()
return _u
}
// SetMethod sets the "method" field.
func (_u *RequestCaptureLogUpdate) SetMethod(v string) *RequestCaptureLogUpdate {
_u.mutation.SetMethod(v)
return _u
}
// SetNillableMethod sets the "method" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableMethod(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetMethod(*v)
}
return _u
}
// ClearMethod clears the value of the "method" field.
func (_u *RequestCaptureLogUpdate) ClearMethod() *RequestCaptureLogUpdate {
_u.mutation.ClearMethod()
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *RequestCaptureLogUpdate) SetIPAddress(v string) *RequestCaptureLogUpdate {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableIPAddress(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *RequestCaptureLogUpdate) ClearIPAddress() *RequestCaptureLogUpdate {
_u.mutation.ClearIPAddress()
return _u
}
// SetRequestBody sets the "request_body" field.
func (_u *RequestCaptureLogUpdate) SetRequestBody(v string) *RequestCaptureLogUpdate {
_u.mutation.SetRequestBody(v)
return _u
}
// SetNillableRequestBody sets the "request_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableRequestBody(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetRequestBody(*v)
}
return _u
}
// ClearRequestBody clears the value of the "request_body" field.
func (_u *RequestCaptureLogUpdate) ClearRequestBody() *RequestCaptureLogUpdate {
_u.mutation.ClearRequestBody()
return _u
}
// SetResponseBody sets the "response_body" field.
func (_u *RequestCaptureLogUpdate) SetResponseBody(v string) *RequestCaptureLogUpdate {
_u.mutation.SetResponseBody(v)
return _u
}
// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableResponseBody(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetResponseBody(*v)
}
return _u
}
// ClearResponseBody clears the value of the "response_body" field.
func (_u *RequestCaptureLogUpdate) ClearResponseBody() *RequestCaptureLogUpdate {
_u.mutation.ClearResponseBody()
return _u
}
// SetNfsFilePath sets the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdate) SetNfsFilePath(v string) *RequestCaptureLogUpdate {
_u.mutation.SetNfsFilePath(v)
return _u
}
// SetNillableNfsFilePath sets the "nfs_file_path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdate) SetNillableNfsFilePath(v *string) *RequestCaptureLogUpdate {
if v != nil {
_u.SetNfsFilePath(*v)
}
return _u
}
// ClearNfsFilePath clears the value of the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdate) ClearNfsFilePath() *RequestCaptureLogUpdate {
_u.mutation.ClearNfsFilePath()
return _u
}
// Mutation returns the RequestCaptureLogMutation object of the builder.
func (_u *RequestCaptureLogUpdate) Mutation() *RequestCaptureLogMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *RequestCaptureLogUpdate) Save(ctx context.Context) (int, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *RequestCaptureLogUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *RequestCaptureLogUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *RequestCaptureLogUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *RequestCaptureLogUpdate) check() error {
if v, ok := _u.mutation.RequestID(); ok {
if err := requestcapturelog.RequestIDValidator(v); err != nil {
return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.request_id": %w`, err)}
}
}
if v, ok := _u.mutation.Path(); ok {
if err := requestcapturelog.PathValidator(v); err != nil {
return &ValidationError{Name: "path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.path": %w`, err)}
}
}
if v, ok := _u.mutation.Method(); ok {
if err := requestcapturelog.MethodValidator(v); err != nil {
return &ValidationError{Name: "method", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.method": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := requestcapturelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.NfsFilePath(); ok {
if err := requestcapturelog.NfsFilePathValidator(v); err != nil {
return &ValidationError{Name: "nfs_file_path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.nfs_file_path": %w`, err)}
}
}
return nil
}
func (_u *RequestCaptureLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(requestcapturelog.Table, requestcapturelog.Columns, sqlgraph.NewFieldSpec(requestcapturelog.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.APIKeyID(); ok {
_spec.SetField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedAPIKeyID(); ok {
_spec.AddField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.UserID(); ok {
_spec.SetField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUserID(); ok {
_spec.AddField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.RequestID(); ok {
_spec.SetField(requestcapturelog.FieldRequestID, field.TypeString, value)
}
if _u.mutation.RequestIDCleared() {
_spec.ClearField(requestcapturelog.FieldRequestID, field.TypeString)
}
if value, ok := _u.mutation.Path(); ok {
_spec.SetField(requestcapturelog.FieldPath, field.TypeString, value)
}
if _u.mutation.PathCleared() {
_spec.ClearField(requestcapturelog.FieldPath, field.TypeString)
}
if value, ok := _u.mutation.Method(); ok {
_spec.SetField(requestcapturelog.FieldMethod, field.TypeString, value)
}
if _u.mutation.MethodCleared() {
_spec.ClearField(requestcapturelog.FieldMethod, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(requestcapturelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(requestcapturelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.RequestBody(); ok {
_spec.SetField(requestcapturelog.FieldRequestBody, field.TypeString, value)
}
if _u.mutation.RequestBodyCleared() {
_spec.ClearField(requestcapturelog.FieldRequestBody, field.TypeString)
}
if value, ok := _u.mutation.ResponseBody(); ok {
_spec.SetField(requestcapturelog.FieldResponseBody, field.TypeString, value)
}
if _u.mutation.ResponseBodyCleared() {
_spec.ClearField(requestcapturelog.FieldResponseBody, field.TypeString)
}
if value, ok := _u.mutation.NfsFilePath(); ok {
_spec.SetField(requestcapturelog.FieldNfsFilePath, field.TypeString, value)
}
if _u.mutation.NfsFilePathCleared() {
_spec.ClearField(requestcapturelog.FieldNfsFilePath, field.TypeString)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{requestcapturelog.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// RequestCaptureLogUpdateOne is the builder for updating a single RequestCaptureLog entity.
type RequestCaptureLogUpdateOne struct {
config
fields []string
hooks []Hook
mutation *RequestCaptureLogMutation
}
// SetAPIKeyID sets the "api_key_id" field.
func (_u *RequestCaptureLogUpdateOne) SetAPIKeyID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.ResetAPIKeyID()
_u.mutation.SetAPIKeyID(v)
return _u
}
// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableAPIKeyID(v *int64) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetAPIKeyID(*v)
}
return _u
}
// AddAPIKeyID adds value to the "api_key_id" field.
func (_u *RequestCaptureLogUpdateOne) AddAPIKeyID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.AddAPIKeyID(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *RequestCaptureLogUpdateOne) SetUserID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.ResetUserID()
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableUserID(v *int64) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// AddUserID adds value to the "user_id" field.
func (_u *RequestCaptureLogUpdateOne) AddUserID(v int64) *RequestCaptureLogUpdateOne {
_u.mutation.AddUserID(v)
return _u
}
// SetRequestID sets the "request_id" field.
func (_u *RequestCaptureLogUpdateOne) SetRequestID(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetRequestID(v)
return _u
}
// SetNillableRequestID sets the "request_id" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableRequestID(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetRequestID(*v)
}
return _u
}
// ClearRequestID clears the value of the "request_id" field.
func (_u *RequestCaptureLogUpdateOne) ClearRequestID() *RequestCaptureLogUpdateOne {
_u.mutation.ClearRequestID()
return _u
}
// SetPath sets the "path" field.
func (_u *RequestCaptureLogUpdateOne) SetPath(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetPath(v)
return _u
}
// SetNillablePath sets the "path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillablePath(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetPath(*v)
}
return _u
}
// ClearPath clears the value of the "path" field.
func (_u *RequestCaptureLogUpdateOne) ClearPath() *RequestCaptureLogUpdateOne {
_u.mutation.ClearPath()
return _u
}
// SetMethod sets the "method" field.
func (_u *RequestCaptureLogUpdateOne) SetMethod(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetMethod(v)
return _u
}
// SetNillableMethod sets the "method" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableMethod(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetMethod(*v)
}
return _u
}
// ClearMethod clears the value of the "method" field.
func (_u *RequestCaptureLogUpdateOne) ClearMethod() *RequestCaptureLogUpdateOne {
_u.mutation.ClearMethod()
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *RequestCaptureLogUpdateOne) SetIPAddress(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableIPAddress(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *RequestCaptureLogUpdateOne) ClearIPAddress() *RequestCaptureLogUpdateOne {
_u.mutation.ClearIPAddress()
return _u
}
// SetRequestBody sets the "request_body" field.
func (_u *RequestCaptureLogUpdateOne) SetRequestBody(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetRequestBody(v)
return _u
}
// SetNillableRequestBody sets the "request_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableRequestBody(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetRequestBody(*v)
}
return _u
}
// ClearRequestBody clears the value of the "request_body" field.
func (_u *RequestCaptureLogUpdateOne) ClearRequestBody() *RequestCaptureLogUpdateOne {
_u.mutation.ClearRequestBody()
return _u
}
// SetResponseBody sets the "response_body" field.
func (_u *RequestCaptureLogUpdateOne) SetResponseBody(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetResponseBody(v)
return _u
}
// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableResponseBody(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetResponseBody(*v)
}
return _u
}
// ClearResponseBody clears the value of the "response_body" field.
func (_u *RequestCaptureLogUpdateOne) ClearResponseBody() *RequestCaptureLogUpdateOne {
_u.mutation.ClearResponseBody()
return _u
}
// SetNfsFilePath sets the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdateOne) SetNfsFilePath(v string) *RequestCaptureLogUpdateOne {
_u.mutation.SetNfsFilePath(v)
return _u
}
// SetNillableNfsFilePath sets the "nfs_file_path" field if the given value is not nil.
func (_u *RequestCaptureLogUpdateOne) SetNillableNfsFilePath(v *string) *RequestCaptureLogUpdateOne {
if v != nil {
_u.SetNfsFilePath(*v)
}
return _u
}
// ClearNfsFilePath clears the value of the "nfs_file_path" field.
func (_u *RequestCaptureLogUpdateOne) ClearNfsFilePath() *RequestCaptureLogUpdateOne {
_u.mutation.ClearNfsFilePath()
return _u
}
// Mutation returns the RequestCaptureLogMutation object of the builder.
func (_u *RequestCaptureLogUpdateOne) Mutation() *RequestCaptureLogMutation {
return _u.mutation
}
// Where appends a list predicates to the RequestCaptureLogUpdate builder.
func (_u *RequestCaptureLogUpdateOne) Where(ps ...predicate.RequestCaptureLog) *RequestCaptureLogUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *RequestCaptureLogUpdateOne) Select(field string, fields ...string) *RequestCaptureLogUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated RequestCaptureLog entity.
func (_u *RequestCaptureLogUpdateOne) Save(ctx context.Context) (*RequestCaptureLog, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *RequestCaptureLogUpdateOne) SaveX(ctx context.Context) *RequestCaptureLog {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *RequestCaptureLogUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *RequestCaptureLogUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *RequestCaptureLogUpdateOne) check() error {
if v, ok := _u.mutation.RequestID(); ok {
if err := requestcapturelog.RequestIDValidator(v); err != nil {
return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.request_id": %w`, err)}
}
}
if v, ok := _u.mutation.Path(); ok {
if err := requestcapturelog.PathValidator(v); err != nil {
return &ValidationError{Name: "path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.path": %w`, err)}
}
}
if v, ok := _u.mutation.Method(); ok {
if err := requestcapturelog.MethodValidator(v); err != nil {
return &ValidationError{Name: "method", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.method": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := requestcapturelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.NfsFilePath(); ok {
if err := requestcapturelog.NfsFilePathValidator(v); err != nil {
return &ValidationError{Name: "nfs_file_path", err: fmt.Errorf(`ent: validator failed for field "RequestCaptureLog.nfs_file_path": %w`, err)}
}
}
return nil
}
func (_u *RequestCaptureLogUpdateOne) sqlSave(ctx context.Context) (_node *RequestCaptureLog, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(requestcapturelog.Table, requestcapturelog.Columns, sqlgraph.NewFieldSpec(requestcapturelog.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "RequestCaptureLog.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, requestcapturelog.FieldID)
for _, f := range fields {
if !requestcapturelog.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != requestcapturelog.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.APIKeyID(); ok {
_spec.SetField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedAPIKeyID(); ok {
_spec.AddField(requestcapturelog.FieldAPIKeyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.UserID(); ok {
_spec.SetField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUserID(); ok {
_spec.AddField(requestcapturelog.FieldUserID, field.TypeInt64, value)
}
if value, ok := _u.mutation.RequestID(); ok {
_spec.SetField(requestcapturelog.FieldRequestID, field.TypeString, value)
}
if _u.mutation.RequestIDCleared() {
_spec.ClearField(requestcapturelog.FieldRequestID, field.TypeString)
}
if value, ok := _u.mutation.Path(); ok {
_spec.SetField(requestcapturelog.FieldPath, field.TypeString, value)
}
if _u.mutation.PathCleared() {
_spec.ClearField(requestcapturelog.FieldPath, field.TypeString)
}
if value, ok := _u.mutation.Method(); ok {
_spec.SetField(requestcapturelog.FieldMethod, field.TypeString, value)
}
if _u.mutation.MethodCleared() {
_spec.ClearField(requestcapturelog.FieldMethod, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(requestcapturelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(requestcapturelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.RequestBody(); ok {
_spec.SetField(requestcapturelog.FieldRequestBody, field.TypeString, value)
}
if _u.mutation.RequestBodyCleared() {
_spec.ClearField(requestcapturelog.FieldRequestBody, field.TypeString)
}
if value, ok := _u.mutation.ResponseBody(); ok {
_spec.SetField(requestcapturelog.FieldResponseBody, field.TypeString, value)
}
if _u.mutation.ResponseBodyCleared() {
_spec.ClearField(requestcapturelog.FieldResponseBody, field.TypeString)
}
if value, ok := _u.mutation.NfsFilePath(); ok {
_spec.SetField(requestcapturelog.FieldNfsFilePath, field.TypeString, value)
}
if _u.mutation.NfsFilePathCleared() {
_spec.ClearField(requestcapturelog.FieldNfsFilePath, field.TypeString)
}
_node = &RequestCaptureLog{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{requestcapturelog.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}
...@@ -28,6 +28,7 @@ import ( ...@@ -28,6 +28,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/requestcapturelog"
"github.com/Wei-Shaw/sub2api/ent/schema" "github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/setting"
...@@ -140,6 +141,10 @@ func init() { ...@@ -140,6 +141,10 @@ func init() {
apikeyDescUsage7d := apikeyFields[16].Descriptor() apikeyDescUsage7d := apikeyFields[16].Descriptor()
// apikey.DefaultUsage7d holds the default value on creation for the usage_7d field. // apikey.DefaultUsage7d holds the default value on creation for the usage_7d field.
apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64) apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64)
// apikeyDescCaptureRequests is the schema descriptor for capture_requests field.
apikeyDescCaptureRequests := apikeyFields[20].Descriptor()
// apikey.DefaultCaptureRequests holds the default value on creation for the capture_requests field.
apikey.DefaultCaptureRequests = apikeyDescCaptureRequests.Default.(bool)
accountMixin := schema.Account{}.Mixin() accountMixin := schema.Account{}.Mixin()
accountMixinHooks1 := accountMixin[1].Hooks() accountMixinHooks1 := accountMixin[1].Hooks()
account.Hooks[0] = accountMixinHooks1[0] account.Hooks[0] = accountMixinHooks1[0]
...@@ -1377,6 +1382,32 @@ func init() { ...@@ -1377,6 +1382,32 @@ func init() {
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor() redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field. // redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int) redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
requestcapturelogFields := schema.RequestCaptureLog{}.Fields()
_ = requestcapturelogFields
// requestcapturelogDescRequestID is the schema descriptor for request_id field.
requestcapturelogDescRequestID := requestcapturelogFields[2].Descriptor()
// requestcapturelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save.
requestcapturelog.RequestIDValidator = requestcapturelogDescRequestID.Validators[0].(func(string) error)
// requestcapturelogDescPath is the schema descriptor for path field.
requestcapturelogDescPath := requestcapturelogFields[3].Descriptor()
// requestcapturelog.PathValidator is a validator for the "path" field. It is called by the builders before save.
requestcapturelog.PathValidator = requestcapturelogDescPath.Validators[0].(func(string) error)
// requestcapturelogDescMethod is the schema descriptor for method field.
requestcapturelogDescMethod := requestcapturelogFields[4].Descriptor()
// requestcapturelog.MethodValidator is a validator for the "method" field. It is called by the builders before save.
requestcapturelog.MethodValidator = requestcapturelogDescMethod.Validators[0].(func(string) error)
// requestcapturelogDescIPAddress is the schema descriptor for ip_address field.
requestcapturelogDescIPAddress := requestcapturelogFields[5].Descriptor()
// requestcapturelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
requestcapturelog.IPAddressValidator = requestcapturelogDescIPAddress.Validators[0].(func(string) error)
// requestcapturelogDescNfsFilePath is the schema descriptor for nfs_file_path field.
requestcapturelogDescNfsFilePath := requestcapturelogFields[8].Descriptor()
// requestcapturelog.NfsFilePathValidator is a validator for the "nfs_file_path" field. It is called by the builders before save.
requestcapturelog.NfsFilePathValidator = requestcapturelogDescNfsFilePath.Validators[0].(func(string) error)
// requestcapturelogDescCreatedAt is the schema descriptor for created_at field.
requestcapturelogDescCreatedAt := requestcapturelogFields[9].Descriptor()
// requestcapturelog.DefaultCreatedAt holds the default value on creation for the created_at field.
requestcapturelog.DefaultCreatedAt = requestcapturelogDescCreatedAt.Default.(func() time.Time)
securitysecretMixin := schema.SecuritySecret{}.Mixin() securitysecretMixin := schema.SecuritySecret{}.Mixin()
securitysecretMixinFields0 := securitysecretMixin[0].Fields() securitysecretMixinFields0 := securitysecretMixin[0].Fields()
_ = securitysecretMixinFields0 _ = securitysecretMixinFields0
......
...@@ -115,6 +115,11 @@ func (APIKey) Fields() []ent.Field { ...@@ -115,6 +115,11 @@ func (APIKey) Fields() []ent.Field {
Optional(). Optional().
Nillable(). Nillable().
Comment("Start time of the current 7d rate limit window"), Comment("Start time of the current 7d rate limit window"),
// ========== Request capture ==========
field.Bool("capture_requests").
Default(false).
Comment("是否对该 API Key 的请求体进行存储捕获"),
} }
} }
......
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// RequestCaptureLog 记录指定 API Key 的请求体,用于审计和分析。
// 只追加,不支持更新/删除(同 PaymentAuditLog 模式)。
type RequestCaptureLog struct {
ent.Schema
}
func (RequestCaptureLog) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "request_capture_logs"},
}
}
func (RequestCaptureLog) Fields() []ent.Field {
return []ent.Field{
field.Int64("api_key_id"),
field.Int64("user_id"),
field.String("request_id").
MaxLen(64).
Optional().
Nillable(),
field.String("path").
MaxLen(100).
Optional().
Nillable(),
field.String("method").
MaxLen(10).
Optional().
Nillable(),
field.String("ip_address").
MaxLen(45).
Optional().
Nillable(),
// request_body 存原始 JSON 文本,不加索引,避免影响查询计划
field.Text("request_body").
Optional().
Nillable(),
// response_body 存响应文本(非 streaming 为完整 JSON,streaming 为拼接的 assistant text)
field.Text("response_body").
Optional().
Nillable(),
// nfs_file_path NFS 文件路径快照,方便核查
field.String("nfs_file_path").
MaxLen(500).
Optional().
Nillable(),
field.Time("created_at").
Default(time.Now).
Immutable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (RequestCaptureLog) Edges() []ent.Edge {
return nil
}
func (RequestCaptureLog) Indexes() []ent.Index {
return []ent.Index{
index.Fields("api_key_id", "created_at"),
index.Fields("user_id"),
}
}
...@@ -60,6 +60,8 @@ type Tx struct { ...@@ -60,6 +60,8 @@ type Tx struct {
Proxy *ProxyClient Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders. // RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient RedeemCode *RedeemCodeClient
// RequestCaptureLog is the client for interacting with the RequestCaptureLog builders.
RequestCaptureLog *RequestCaptureLogClient
// SecuritySecret is the client for interacting with the SecuritySecret builders. // SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders. // Setting is the client for interacting with the Setting builders.
...@@ -236,6 +238,7 @@ func (tx *Tx) init() { ...@@ -236,6 +238,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config) tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.RequestCaptureLog = NewRequestCaptureLogClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config) tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config) tx.Setting = NewSettingClient(tx.config)
tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config) tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config)
......
...@@ -89,6 +89,16 @@ type Config struct { ...@@ -89,6 +89,16 @@ type Config struct {
Gemini GeminiConfig `mapstructure:"gemini"` Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"` Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"` Idempotency IdempotencyConfig `mapstructure:"idempotency"`
RequestCapture RequestCaptureConfig `mapstructure:"request_capture"`
}
// RequestCaptureConfig 配置请求体捕获功能
type RequestCaptureConfig struct {
// NFSPath 为本地挂载的 NFS 根目录(例如 /mnt/nfs/requests)。
// 留空则跳过文件写入,只写数据库。
NFSPath string `mapstructure:"nfs_path"`
// WorkerTimeoutSeconds 单次异步写入的超时时间(秒),默认 5。
WorkerTimeoutSeconds int `mapstructure:"worker_timeout_seconds"`
} }
type LogConfig struct { type LogConfig struct {
......
...@@ -565,6 +565,17 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i ...@@ -565,6 +565,17 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
func (s *stubAdminService) AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*service.APIKey, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
s.apiKeys[i].CaptureRequests = enabled
k := s.apiKeys[i]
return &k, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error { func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
return nil return nil
} }
......
...@@ -10,6 +10,11 @@ import ( ...@@ -10,6 +10,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// AdminSetCaptureRequestsRequest 请求体:设置/清除 capture_requests
type AdminSetCaptureRequestsRequest struct {
Enabled bool `json:"enabled"`
}
// AdminAPIKeyHandler handles admin API key management // AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct { type AdminAPIKeyHandler struct {
adminService service.AdminService adminService service.AdminService
...@@ -27,6 +32,33 @@ type AdminUpdateAPIKeyGroupRequest struct { ...@@ -27,6 +32,33 @@ type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组 GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
} }
// SetCaptureRequests 开启或关闭指定 API Key 的请求体捕获,并立即失效认证缓存。
// PUT /api/v1/admin/api-keys/:id/capture-requests
func (h *AdminAPIKeyHandler) SetCaptureRequests(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminSetCaptureRequestsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
apiKey, err := h.adminService.AdminSetCaptureRequests(c.Request.Context(), keyID, req.Enabled)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"id": apiKey.ID,
"capture_requests": apiKey.CaptureRequests,
})
}
// UpdateGroup handles updating an API key's group binding // UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id // PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
......
...@@ -45,6 +45,7 @@ type GatewayHandler struct { ...@@ -45,6 +45,7 @@ type GatewayHandler struct {
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
requestCaptureService *service.RequestCaptureService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper userMsgQueueHelper *UserMsgQueueHelper
maxAccountSwitches int maxAccountSwitches int
...@@ -65,6 +66,7 @@ func NewGatewayHandler( ...@@ -65,6 +66,7 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool, usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService, errorPassthroughService *service.ErrorPassthroughService,
requestCaptureService *service.RequestCaptureService,
userMsgQueueService *service.UserMessageQueueService, userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config, cfg *config.Config,
settingService *service.SettingService, settingService *service.SettingService,
...@@ -98,6 +100,7 @@ func NewGatewayHandler( ...@@ -98,6 +100,7 @@ func NewGatewayHandler(
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool, usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
requestCaptureService: requestCaptureService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper, userMsgQueueHelper: umqHelper,
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
...@@ -147,6 +150,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -147,6 +150,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
...@@ -158,6 +175,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -158,6 +175,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 若 request_capture 已启用且为流式请求,注入响应体采集 buffer 到 context
// service 层的 handleStreamingResponse 会将 text_delta 内容写入此 buffer
if captureID > 0 && reqStream {
captureRespBuilder := &strings.Builder{}
c.Request = c.Request.WithContext(
context.WithValue(c.Request.Context(), ctxkey.ResponseCaptureBuffer, captureRespBuilder),
)
}
// 解析渠道级模型映射 // 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
...@@ -483,6 +509,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -483,6 +509,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
} }
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
...@@ -840,6 +871,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -840,6 +871,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
} }
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"time" "time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -60,6 +61,20 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -60,6 +61,20 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
return return
} }
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
// Validate JSON // Validate JSON
...@@ -253,6 +268,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -253,6 +268,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c) inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"time" "time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -62,6 +63,20 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -62,6 +63,20 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return return
} }
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
if !gjson.ValidBytes(body) { if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
...@@ -265,6 +280,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -265,6 +280,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
} }
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -32,6 +33,7 @@ type OpenAIGatewayHandler struct { ...@@ -32,6 +33,7 @@ type OpenAIGatewayHandler struct {
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
requestCaptureService *service.RequestCaptureService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
cfg *config.Config cfg *config.Config
...@@ -62,6 +64,7 @@ func NewOpenAIGatewayHandler( ...@@ -62,6 +64,7 @@ func NewOpenAIGatewayHandler(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool, usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService, errorPassthroughService *service.ErrorPassthroughService,
requestCaptureService *service.RequestCaptureService,
cfg *config.Config, cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
...@@ -78,6 +81,7 @@ func NewOpenAIGatewayHandler( ...@@ -78,6 +81,7 @@ func NewOpenAIGatewayHandler(
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool, usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
requestCaptureService: requestCaptureService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
cfg: cfg, cfg: cfg,
...@@ -135,6 +139,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -135,6 +139,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
sessionHashBody := body sessionHashBody := body
if service.IsOpenAIResponsesCompactPathForTest(c) { if service.IsOpenAIResponsesCompactPathForTest(c) {
...@@ -389,6 +407,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -389,6 +407,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
} }
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
...@@ -590,6 +613,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -590,6 +613,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
c.Request.Method,
c.ClientIP(),
body,
)
}
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now() routingStart := time.Now()
...@@ -764,6 +801,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -764,6 +801,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
} }
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body) requestPayloadHash := service.HashUsageRequestPayload(body)
......
...@@ -55,4 +55,8 @@ const ( ...@@ -55,4 +55,8 @@ const (
// ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22") // ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
ClaudeCodeVersion Key = "ctx_claude_code_version" ClaudeCodeVersion Key = "ctx_claude_code_version"
// ResponseCaptureBuffer 用于在 streaming 响应中收集 assistant 文本,供 request_capture 功能使用。
// 值类型为 *strings.Builder,由 handler 层注入,service 层只负责追加文本。
ResponseCaptureBuffer Key = "ctx_response_capture_buffer"
) )
...@@ -134,6 +134,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -134,6 +134,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldRateLimit5h, apikey.FieldRateLimit5h,
apikey.FieldRateLimit1d, apikey.FieldRateLimit1d,
apikey.FieldRateLimit7d, apikey.FieldRateLimit7d,
apikey.FieldCaptureRequests,
). ).
WithUser(func(q *dbent.UserQuery) { WithUser(func(q *dbent.UserQuery) {
q.Select( q.Select(
...@@ -255,6 +256,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro ...@@ -255,6 +256,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearIPBlacklist() builder.ClearIPBlacklist()
} }
builder.SetCaptureRequests(key.CaptureRequests)
affected, err := builder.Save(ctx) affected, err := builder.Save(ctx)
if err != nil { if err != nil {
return err return err
...@@ -634,9 +637,10 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { ...@@ -634,9 +637,10 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Usage5h: m.Usage5h, Usage5h: m.Usage5h,
Usage1d: m.Usage1d, Usage1d: m.Usage1d,
Usage7d: m.Usage7d, Usage7d: m.Usage7d,
Window5hStart: m.Window5hStart, Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart, Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart, Window7dStart: m.Window7dStart,
CaptureRequests: m.CaptureRequests,
} }
if m.Edges.User != nil { if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User) out.User = userEntityToService(m.Edges.User)
......
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type requestCaptureLogRepository struct {
client *dbent.Client
}
// NewRequestCaptureLogRepository 创建请求捕获日志仓储实例。
func NewRequestCaptureLogRepository(client *dbent.Client) service.RequestCaptureLogRepository {
return &requestCaptureLogRepository{client: client}
}
func (r *requestCaptureLogRepository) Create(ctx context.Context, params service.CreateRequestCaptureLogParams) (int64, error) {
q := r.client.RequestCaptureLog.Create().
SetAPIKeyID(params.APIKeyID).
SetUserID(params.UserID)
if params.RequestID != "" {
q = q.SetRequestID(params.RequestID)
}
if params.Path != "" {
q = q.SetPath(params.Path)
}
if params.Method != "" {
q = q.SetMethod(params.Method)
}
if params.IPAddress != "" {
q = q.SetIPAddress(params.IPAddress)
}
if params.RequestBody != "" {
q = q.SetRequestBody(params.RequestBody)
}
if params.NFSFilePath != "" {
q = q.SetNfsFilePath(params.NFSFilePath)
}
row, err := q.Save(ctx)
if err != nil {
return 0, err
}
return row.ID, nil
}
func (r *requestCaptureLogRepository) UpdateResponseBody(ctx context.Context, id int64, responseBody string) error {
return r.client.RequestCaptureLog.UpdateOneID(id).
SetResponseBody(responseBody).
Exec(ctx)
}
...@@ -92,6 +92,7 @@ var ProviderSet = wire.NewSet( ...@@ -92,6 +92,7 @@ var ProviderSet = wire.NewSet(
NewChannelMonitorRepository, NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository, NewChannelMonitorRequestTemplateRepository,
NewAffiliateRepository, NewAffiliateRepository,
NewRequestCaptureLogRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
......
...@@ -101,6 +101,7 @@ func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -101,6 +101,7 @@ func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
apiKeys := admin.Group("/api-keys") apiKeys := admin.Group("/api-keys")
{ {
apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup) apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
apiKeys.PUT("/:id/capture-requests", h.Admin.APIKey.SetCaptureRequests)
} }
} }
......
...@@ -58,6 +58,7 @@ type AdminService interface { ...@@ -58,6 +58,7 @@ type AdminService interface {
// API Key management (admin) // API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*APIKey, error)
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限 // ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
...@@ -1961,6 +1962,22 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i ...@@ -1961,6 +1962,22 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return result, nil return result, nil
} }
// AdminSetCaptureRequests 设置或清除指定 API Key 的请求捕获开关,并立即失效认证缓存。
func (s *adminServiceImpl) AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
if err != nil {
return nil, err
}
apiKey.CaptureRequests = enabled
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
return apiKey, nil
}
// ReplaceUserGroup 替换用户的专属分组 // ReplaceUserGroup 替换用户的专属分组
func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
if oldGroupID == newGroupID { if oldGroupID == newGroupID {
......
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
...@@ -1739,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1739,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
var responseBody string
if claudeReq.Stream { if claudeReq.Stream {
// 客户端要求流式,直接透传转换 // 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
...@@ -1749,6 +1751,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1749,6 +1751,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
usage = streamRes.usage usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect clientDisconnect = streamRes.clientDisconnect
// 流式:从 context buffer 读取采集的文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else { } else {
// 客户端要求非流式,收集流式响应后转换返回 // 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
...@@ -1758,6 +1764,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1758,6 +1764,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
} }
usage = streamRes.usage usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
} }
return &ForwardResult{ return &ForwardResult{
...@@ -1769,6 +1776,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1769,6 +1776,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect, ClientDisconnect: clientDisconnect,
ResponseBody: responseBody,
}, nil }, nil
} }
...@@ -2421,6 +2429,7 @@ handleSuccess: ...@@ -2421,6 +2429,7 @@ handleSuccess:
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
var responseBody string
if stream { if stream {
// 客户端要求流式,直接透传 // 客户端要求流式,直接透传
...@@ -2432,6 +2441,10 @@ handleSuccess: ...@@ -2432,6 +2441,10 @@ handleSuccess:
usage = streamRes.usage usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect clientDisconnect = streamRes.clientDisconnect
// 流式:从 context buffer 读取采集的文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else { } else {
// 客户端要求非流式,收集流式响应后返回 // 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
...@@ -2441,6 +2454,7 @@ handleSuccess: ...@@ -2441,6 +2454,7 @@ handleSuccess:
} }
usage = streamRes.usage usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
} }
if usage == nil { if usage == nil {
...@@ -2465,6 +2479,7 @@ handleSuccess: ...@@ -2465,6 +2479,7 @@ handleSuccess:
ClientDisconnect: clientDisconnect, ClientDisconnect: clientDisconnect,
ImageCount: imageCount, ImageCount: imageCount,
ImageSize: imageSize, ImageSize: imageSize,
ResponseBody: responseBody,
}, nil }, nil
} }
...@@ -2968,7 +2983,8 @@ func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur ...@@ -2968,7 +2983,8 @@ func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur
type antigravityStreamResult struct { type antigravityStreamResult struct {
usage *ClaudeUsage usage *ClaudeUsage
firstTokenMs *int firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开 clientDisconnect bool // 客户端是否在流式传输过程中断开
responseBody string // 响应体内容(非流式:完整 JSON;流式:从上下文 buffer 读取)
} }
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。 // antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
...@@ -3124,6 +3140,9 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -3124,6 +3140,9 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini") cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本内容
captureBuilder, _ := c.Request.Context().Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
// 仅发送一次错误事件,避免多次写入导致协议混乱 // 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false errorEventSent := false
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason string) {
...@@ -3197,6 +3216,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -3197,6 +3216,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs = &ms firstTokenMs = &ms
} }
// 采集文本用于响应捕获
if captureBuilder != nil && len(inner) > 0 {
gjson.GetBytes(inner, "candidates.0.content.parts").ForEach(func(_, v gjson.Result) bool {
if t := v.Get("text").String(); t != "" {
captureBuilder.WriteString(t)
}
return true
})
}
cw.Fprintf("data: %s\n\n", payload) cw.Fprintf("data: %s\n\n", payload)
continue continue
} }
...@@ -3418,7 +3447,7 @@ returnResponse: ...@@ -3418,7 +3447,7 @@ returnResponse:
} }
c.Data(http.StatusOK, "application/json", respBody) c.Data(http.StatusOK, "application/json", respBody)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, responseBody: strings.Join(collectedTextParts, "")}, nil
} }
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调 // getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
...@@ -3867,7 +3896,7 @@ returnResponse: ...@@ -3867,7 +3896,7 @@ returnResponse:
CacheReadInputTokens: agUsage.CacheReadInputTokens, CacheReadInputTokens: agUsage.CacheReadInputTokens,
} }
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, responseBody: string(claudeResp)}, nil
} }
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换) // handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
...@@ -3971,6 +4000,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -3971,6 +4000,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude") cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本 delta
captureBuilder, _ := c.Request.Context().Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
// 仅发送一次错误事件,避免多次写入导致协议混乱 // 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false errorEventSent := false
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason string) {
...@@ -4024,7 +4056,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -4024,7 +4056,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
lastDataAt = time.Now() lastDataAt = time.Now()
// 处理 SSE 行,转换为 Claude 格式 // 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n")) trimmedLine := strings.TrimRight(ev.line, "\r\n")
claudeEvents := processor.ProcessLine(trimmedLine)
if len(claudeEvents) > 0 { if len(claudeEvents) > 0 {
if firstTokenMs == nil { if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
...@@ -4033,6 +4066,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -4033,6 +4066,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw.Write(claudeEvents) cw.Write(claudeEvents)
} }
// 采集文本用于响应捕获
if captureBuilder != nil && strings.HasPrefix(trimmedLine, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(trimmedLine, "data:"))
if data != "" && data != "[DONE]" {
// v1internal 格式: response.candidates.0.content.parts.0.text
// 直接 Gemini 格式: candidates.0.content.parts.0.text
text := gjson.Get(data, "response.candidates.0.content.parts.0.text").String()
if text == "" {
text = gjson.Get(data, "candidates.0.content.parts.0.text").String()
}
if text != "" {
captureBuilder.WriteString(text)
}
}
}
case <-intervalCh: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval { if time.Since(lastRead) < streamInterval {
...@@ -4286,6 +4335,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -4286,6 +4335,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
var responseBody string
if claudeReq.Stream { if claudeReq.Stream {
// 流式响应:透传 // 流式响应:透传
...@@ -4295,10 +4345,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -4295,10 +4345,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
c.Header("X-Accel-Buffering", "no") c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK) c.Status(http.StatusOK)
streamRes := s.streamUpstreamResponse(c, resp, startTime) streamRes := s.streamUpstreamResponse(ctx, c, resp, startTime)
usage = streamRes.usage usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect clientDisconnect = streamRes.clientDisconnect
// 从 context buffer 读取已收集的 assistant 文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else { } else {
// 非流式响应:直接透传 // 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body) respBody, err := io.ReadAll(resp.Body)
...@@ -4308,6 +4362,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -4308,6 +4362,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
// 提取 usage // 提取 usage
usage = s.extractClaudeUsage(respBody) usage = s.extractClaudeUsage(respBody)
responseBody = string(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type")) c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK) c.Status(http.StatusOK)
...@@ -4324,6 +4379,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -4324,6 +4379,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
Duration: duration, Duration: duration,
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect, ClientDisconnect: clientDisconnect,
ResponseBody: responseBody,
Usage: ClaudeUsage{ Usage: ClaudeUsage{
InputTokens: usage.InputTokens, InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens, OutputTokens: usage.OutputTokens,
...@@ -4334,10 +4390,13 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -4334,10 +4390,13 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
} }
// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage // streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult { func (s *AntigravityGatewayService) streamUpstreamResponse(ctx context.Context, c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
usage := &ClaudeUsage{} usage := &ClaudeUsage{}
var firstTokenMs *int var firstTokenMs *int
// 响应体捕获:若 context 中注入了 ResponseCaptureBuffer,则收集 text_delta 文本
captureBuilder, _ := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
...@@ -4435,6 +4494,16 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp ...@@ -4435,6 +4494,16 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
// 尝试从 message_delta 或 message_stop 事件提取 usage // 尝试从 message_delta 或 message_stop 事件提取 usage
s.extractSSEUsage(line, usage) s.extractSSEUsage(line, usage)
// 收集 assistant text(仅 content_block_delta + text_delta)
if captureBuilder != nil && strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if gjson.Get(data, "type").String() == "content_block_delta" {
if gjson.Get(data, "delta.type").String() == "text_delta" {
captureBuilder.WriteString(gjson.Get(data, "delta.text").String())
}
}
}
// 透传行 // 透传行
cw.Fprintf("%s\n", line) cw.Fprintf("%s\n", line)
......
...@@ -60,6 +60,9 @@ type APIKey struct { ...@@ -60,6 +60,9 @@ type APIKey struct {
Window5hStart *time.Time // Start of current 5h window Window5hStart *time.Time // Start of current 5h window
Window1dStart *time.Time // Start of current 1d window Window1dStart *time.Time // Start of current 1d window
Window7dStart *time.Time // Start of current 7d window Window7dStart *time.Time // Start of current 7d window
// 请求体捕获
CaptureRequests bool // 是否对该 Key 的请求体进行存储捕获
} }
func (k *APIKey) IsActive() bool { func (k *APIKey) IsActive() bool {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment