Unverified Commit 186e3675 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1194 from Ethan0x0000/feat/requested-upstream-model-semantics

feat(usage): 统一使用记录中的请求模型与上游模型语义
parents 421728a9 27948c77
...@@ -716,6 +716,7 @@ var ( ...@@ -716,6 +716,7 @@ var (
{Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "request_id", Type: field.TypeString, Size: 64},
{Name: "model", Type: field.TypeString, Size: 100}, {Name: "model", Type: field.TypeString, Size: 100},
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0},
...@@ -756,31 +757,31 @@ var ( ...@@ -756,31 +757,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[33]}, Columns: []*schema.Column{UsageLogsColumns[34]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
...@@ -789,38 +790,43 @@ var ( ...@@ -789,38 +790,43 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[33]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33]}, Columns: []*schema.Column{UsageLogsColumns[34]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[2]}, Columns: []*schema.Column{UsageLogsColumns[2]},
}, },
{
Name: "usagelog_requested_model",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[3]},
},
{ {
Name: "usagelog_request_id", Name: "usagelog_request_id",
Unique: false, Unique: false,
...@@ -829,17 +835,17 @@ var ( ...@@ -829,17 +835,17 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_group_id_created_at", Name: "usagelog_group_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]},
}, },
}, },
} }
......
...@@ -18239,6 +18239,7 @@ type UsageLogMutation struct { ...@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
id *int64 id *int64
request_id *string request_id *string
model *string model *string
requested_model *string
upstream_model *string upstream_model *string
input_tokens *int input_tokens *int
addinput_tokens *int addinput_tokens *int
...@@ -18577,6 +18578,55 @@ func (m *UsageLogMutation) ResetModel() { ...@@ -18577,6 +18578,55 @@ func (m *UsageLogMutation) ResetModel() {
m.model = nil m.model = nil
} }
   
// SetRequestedModel sets the "requested_model" field.
func (m *UsageLogMutation) SetRequestedModel(s string) {
m.requested_model = &s
}
// RequestedModel returns the value of the "requested_model" field in the mutation.
func (m *UsageLogMutation) RequestedModel() (r string, exists bool) {
v := m.requested_model
if v == nil {
return
}
return *v, true
}
// OldRequestedModel returns the old "requested_model" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldRequestedModel(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldRequestedModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldRequestedModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldRequestedModel: %w", err)
}
return oldValue.RequestedModel, nil
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (m *UsageLogMutation) ClearRequestedModel() {
m.requested_model = nil
m.clearedFields[usagelog.FieldRequestedModel] = struct{}{}
}
// RequestedModelCleared returns if the "requested_model" field was cleared in this mutation.
func (m *UsageLogMutation) RequestedModelCleared() bool {
_, ok := m.clearedFields[usagelog.FieldRequestedModel]
return ok
}
// ResetRequestedModel resets all changes to the "requested_model" field.
func (m *UsageLogMutation) ResetRequestedModel() {
m.requested_model = nil
delete(m.clearedFields, usagelog.FieldRequestedModel)
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (m *UsageLogMutation) SetUpstreamModel(s string) { func (m *UsageLogMutation) SetUpstreamModel(s string) {
m.upstream_model = &s m.upstream_model = &s
...@@ -20247,7 +20297,7 @@ func (m *UsageLogMutation) Type() string { ...@@ -20247,7 +20297,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 33) fields := make([]string, 0, 34)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
...@@ -20263,6 +20313,9 @@ func (m *UsageLogMutation) Fields() []string { ...@@ -20263,6 +20313,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.model != nil { if m.model != nil {
fields = append(fields, usagelog.FieldModel) fields = append(fields, usagelog.FieldModel)
} }
if m.requested_model != nil {
fields = append(fields, usagelog.FieldRequestedModel)
}
if m.upstream_model != nil { if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel) fields = append(fields, usagelog.FieldUpstreamModel)
} }
...@@ -20365,6 +20418,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { ...@@ -20365,6 +20418,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestID() return m.RequestID()
case usagelog.FieldModel: case usagelog.FieldModel:
return m.Model() return m.Model()
case usagelog.FieldRequestedModel:
return m.RequestedModel()
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
return m.UpstreamModel() return m.UpstreamModel()
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
...@@ -20440,6 +20495,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value ...@@ -20440,6 +20495,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestID(ctx) return m.OldRequestID(ctx)
case usagelog.FieldModel: case usagelog.FieldModel:
return m.OldModel(ctx) return m.OldModel(ctx)
case usagelog.FieldRequestedModel:
return m.OldRequestedModel(ctx)
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx) return m.OldUpstreamModel(ctx)
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
...@@ -20540,6 +20597,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { ...@@ -20540,6 +20597,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetModel(v) m.SetModel(v)
return nil return nil
case usagelog.FieldRequestedModel:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetRequestedModel(v)
return nil
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
v, ok := value.(string) v, ok := value.(string)
if !ok { if !ok {
...@@ -20985,6 +21049,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error { ...@@ -20985,6 +21049,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
// mutation. // mutation.
func (m *UsageLogMutation) ClearedFields() []string { func (m *UsageLogMutation) ClearedFields() []string {
var fields []string var fields []string
if m.FieldCleared(usagelog.FieldRequestedModel) {
fields = append(fields, usagelog.FieldRequestedModel)
}
if m.FieldCleared(usagelog.FieldUpstreamModel) { if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel) fields = append(fields, usagelog.FieldUpstreamModel)
} }
...@@ -21029,6 +21096,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool { ...@@ -21029,6 +21096,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema. // error if the field is not defined in the schema.
func (m *UsageLogMutation) ClearField(name string) error { func (m *UsageLogMutation) ClearField(name string) error {
switch name { switch name {
case usagelog.FieldRequestedModel:
m.ClearRequestedModel()
return nil
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel() m.ClearUpstreamModel()
return nil return nil
...@@ -21082,6 +21152,9 @@ func (m *UsageLogMutation) ResetField(name string) error { ...@@ -21082,6 +21152,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldModel: case usagelog.FieldModel:
m.ResetModel() m.ResetModel()
return nil return nil
case usagelog.FieldRequestedModel:
m.ResetRequestedModel()
return nil
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel() m.ResetUpstreamModel()
return nil return nil
......
...@@ -821,96 +821,100 @@ func init() { ...@@ -821,96 +821,100 @@ func init() {
return nil return nil
} }
}() }()
// usagelogDescRequestedModel is the schema descriptor for requested_model field.
usagelogDescRequestedModel := usagelogFields[5].Descriptor()
// usagelog.RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
usagelog.RequestedModelValidator = usagelogDescRequestedModel.Validators[0].(func(string) error)
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field. // usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
usagelogDescUpstreamModel := usagelogFields[5].Descriptor() usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field. // usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[8].Descriptor() usagelogDescInputTokens := usagelogFields[9].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field. // usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[9].Descriptor() usagelogDescOutputTokens := usagelogFields[10].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor() usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor() usagelogDescCacheReadTokens := usagelogFields[12].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor() usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor() usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field. // usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[14].Descriptor() usagelogDescInputCost := usagelogFields[15].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field. // usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field. // usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[15].Descriptor() usagelogDescOutputCost := usagelogFields[16].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor() usagelogDescCacheCreationCost := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[17].Descriptor() usagelogDescCacheReadCost := usagelogFields[18].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field. // usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[18].Descriptor() usagelogDescTotalCost := usagelogFields[19].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field. // usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[19].Descriptor() usagelogDescActualCost := usagelogFields[20].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[20].Descriptor() usagelogDescRateMultiplier := usagelogFields[21].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field. // usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[22].Descriptor() usagelogDescBillingType := usagelogFields[23].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field. // usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field. // usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[23].Descriptor() usagelogDescStream := usagelogFields[24].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field. // usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool) usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field. // usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[26].Descriptor() usagelogDescUserAgent := usagelogFields[27].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field. // usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[27].Descriptor() usagelogDescIPAddress := usagelogFields[28].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field. // usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[28].Descriptor() usagelogDescImageCount := usagelogFields[29].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field. // usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field. // usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[29].Descriptor() usagelogDescImageSize := usagelogFields[30].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field. // usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[30].Descriptor() usagelogDescMediaType := usagelogFields[31].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor() usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[32].Descriptor() usagelogDescCreatedAt := usagelogFields[33].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()
......
...@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field { ...@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
field.String("model"). field.String("model").
MaxLen(100). MaxLen(100).
NotEmpty(), NotEmpty(),
// RequestedModel stores the client-requested model name for stable display and analytics.
// NULL means historical rows written before requested_model dual-write was introduced.
field.String("requested_model").
MaxLen(100).
Optional().
Nillable(),
// UpstreamModel stores the actual upstream model name when model mapping // UpstreamModel stores the actual upstream model name when model mapping
// is applied. NULL means no mapping — the requested model was used as-is. // is applied. NULL means no mapping — the requested model was used as-is.
field.String("upstream_model"). field.String("upstream_model").
...@@ -181,6 +187,7 @@ func (UsageLog) Indexes() []ent.Index { ...@@ -181,6 +187,7 @@ func (UsageLog) Indexes() []ent.Index {
index.Fields("subscription_id"), index.Fields("subscription_id"),
index.Fields("created_at"), index.Fields("created_at"),
index.Fields("model"), index.Fields("model"),
index.Fields("requested_model"),
index.Fields("request_id"), index.Fields("request_id"),
// 复合索引用于时间范围查询 // 复合索引用于时间范围查询
index.Fields("user_id", "created_at"), index.Fields("user_id", "created_at"),
......
...@@ -32,6 +32,8 @@ type UsageLog struct { ...@@ -32,6 +32,8 @@ type UsageLog struct {
RequestID string `json:"request_id,omitempty"` RequestID string `json:"request_id,omitempty"`
// Model holds the value of the "model" field. // Model holds the value of the "model" field.
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
// RequestedModel holds the value of the "requested_model" field.
RequestedModel *string `json:"requested_model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field. // UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"` UpstreamModel *string `json:"upstream_model,omitempty"`
// GroupID holds the value of the "group_id" field. // GroupID holds the value of the "group_id" field.
...@@ -177,7 +179,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { ...@@ -177,7 +179,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
...@@ -232,6 +234,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { ...@@ -232,6 +234,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Model = value.String _m.Model = value.String
} }
case usagelog.FieldRequestedModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field requested_model", values[i])
} else if value.Valid {
_m.RequestedModel = new(string)
*_m.RequestedModel = value.String
}
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
if value, ok := values[i].(*sql.NullString); !ok { if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field upstream_model", values[i]) return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
...@@ -486,6 +495,11 @@ func (_m *UsageLog) String() string { ...@@ -486,6 +495,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("model=") builder.WriteString("model=")
builder.WriteString(_m.Model) builder.WriteString(_m.Model)
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.RequestedModel; v != nil {
builder.WriteString("requested_model=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.UpstreamModel; v != nil { if v := _m.UpstreamModel; v != nil {
builder.WriteString("upstream_model=") builder.WriteString("upstream_model=")
builder.WriteString(*v) builder.WriteString(*v)
......
...@@ -24,6 +24,8 @@ const ( ...@@ -24,6 +24,8 @@ const (
FieldRequestID = "request_id" FieldRequestID = "request_id"
// FieldModel holds the string denoting the model field in the database. // FieldModel holds the string denoting the model field in the database.
FieldModel = "model" FieldModel = "model"
// FieldRequestedModel holds the string denoting the requested_model field in the database.
FieldRequestedModel = "requested_model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database. // FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model" FieldUpstreamModel = "upstream_model"
// FieldGroupID holds the string denoting the group_id field in the database. // FieldGroupID holds the string denoting the group_id field in the database.
...@@ -137,6 +139,7 @@ var Columns = []string{ ...@@ -137,6 +139,7 @@ var Columns = []string{
FieldAccountID, FieldAccountID,
FieldRequestID, FieldRequestID,
FieldModel, FieldModel,
FieldRequestedModel,
FieldUpstreamModel, FieldUpstreamModel,
FieldGroupID, FieldGroupID,
FieldSubscriptionID, FieldSubscriptionID,
...@@ -182,6 +185,8 @@ var ( ...@@ -182,6 +185,8 @@ var (
RequestIDValidator func(string) error RequestIDValidator func(string) error
// ModelValidator is a validator for the "model" field. It is called by the builders before save. // ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error ModelValidator func(string) error
// RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
RequestedModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error UpstreamModelValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field. // DefaultInputTokens holds the default value on creation for the "input_tokens" field.
...@@ -263,6 +268,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption { ...@@ -263,6 +268,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModel, opts...).ToFunc() return sql.OrderByField(FieldModel, opts...).ToFunc()
} }
// ByRequestedModel orders the results by the requested_model field.
func ByRequestedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRequestedModel, opts...).ToFunc()
}
// ByUpstreamModel orders the results by the upstream_model field. // ByUpstreamModel orders the results by the upstream_model field.
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
......
...@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog { ...@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
} }
// RequestedModel applies equality check predicate on the "requested_model" field. It's identical to RequestedModelEQ.
func RequestedModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
}
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ. // UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
func UpstreamModel(v string) predicate.UsageLog { func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
...@@ -410,6 +415,81 @@ func ModelContainsFold(v string) predicate.UsageLog { ...@@ -410,6 +415,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
} }
// RequestedModelEQ applies the EQ predicate on the "requested_model" field.
func RequestedModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
}
// RequestedModelNEQ applies the NEQ predicate on the "requested_model" field.
func RequestedModelNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldRequestedModel, v))
}
// RequestedModelIn applies the In predicate on the "requested_model" field.
func RequestedModelIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldRequestedModel, vs...))
}
// RequestedModelNotIn applies the NotIn predicate on the "requested_model" field.
func RequestedModelNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldRequestedModel, vs...))
}
// RequestedModelGT applies the GT predicate on the "requested_model" field.
func RequestedModelGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldRequestedModel, v))
}
// RequestedModelGTE applies the GTE predicate on the "requested_model" field.
func RequestedModelGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldRequestedModel, v))
}
// RequestedModelLT applies the LT predicate on the "requested_model" field.
func RequestedModelLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldRequestedModel, v))
}
// RequestedModelLTE applies the LTE predicate on the "requested_model" field.
func RequestedModelLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldRequestedModel, v))
}
// RequestedModelContains applies the Contains predicate on the "requested_model" field.
func RequestedModelContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldRequestedModel, v))
}
// RequestedModelHasPrefix applies the HasPrefix predicate on the "requested_model" field.
func RequestedModelHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestedModel, v))
}
// RequestedModelHasSuffix applies the HasSuffix predicate on the "requested_model" field.
func RequestedModelHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestedModel, v))
}
// RequestedModelIsNil applies the IsNil predicate on the "requested_model" field.
func RequestedModelIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldRequestedModel))
}
// RequestedModelNotNil applies the NotNil predicate on the "requested_model" field.
func RequestedModelNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldRequestedModel))
}
// RequestedModelEqualFold applies the EqualFold predicate on the "requested_model" field.
func RequestedModelEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldRequestedModel, v))
}
// RequestedModelContainsFold applies the ContainsFold predicate on the "requested_model" field.
func RequestedModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldRequestedModel, v))
}
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field. // UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
func UpstreamModelEQ(v string) predicate.UsageLog { func UpstreamModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
......
...@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { ...@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
return _c return _c
} }
// SetRequestedModel sets the "requested_model" field.
func (_c *UsageLogCreate) SetRequestedModel(v string) *UsageLogCreate {
_c.mutation.SetRequestedModel(v)
return _c
}
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableRequestedModel(v *string) *UsageLogCreate {
if v != nil {
_c.SetRequestedModel(*v)
}
return _c
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate { func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
_c.mutation.SetUpstreamModel(v) _c.mutation.SetUpstreamModel(v)
...@@ -610,6 +624,11 @@ func (_c *UsageLogCreate) check() error { ...@@ -610,6 +624,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _c.mutation.RequestedModel(); ok {
if err := usagelog.RequestedModelValidator(v); err != nil {
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
}
}
if v, ok := _c.mutation.UpstreamModel(); ok { if v, ok := _c.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil { if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
...@@ -733,6 +752,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { ...@@ -733,6 +752,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
_node.Model = value _node.Model = value
} }
if value, ok := _c.mutation.RequestedModel(); ok {
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
_node.RequestedModel = &value
}
if value, ok := _c.mutation.UpstreamModel(); ok { if value, ok := _c.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value _node.UpstreamModel = &value
...@@ -1034,6 +1057,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { ...@@ -1034,6 +1057,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
return u return u
} }
// SetRequestedModel sets the "requested_model" field.
func (u *UsageLogUpsert) SetRequestedModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldRequestedModel, v)
return u
}
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateRequestedModel() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldRequestedModel)
return u
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (u *UsageLogUpsert) ClearRequestedModel() *UsageLogUpsert {
u.SetNull(usagelog.FieldRequestedModel)
return u
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert { func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldUpstreamModel, v) u.Set(usagelog.FieldUpstreamModel, v)
...@@ -1641,6 +1682,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { ...@@ -1641,6 +1682,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
}) })
} }
// SetRequestedModel sets the "requested_model" field.
func (u *UsageLogUpsertOne) SetRequestedModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetRequestedModel(v)
})
}
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateRequestedModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateRequestedModel()
})
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (u *UsageLogUpsertOne) ClearRequestedModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearRequestedModel()
})
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne { func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
...@@ -2496,6 +2558,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { ...@@ -2496,6 +2558,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
}) })
} }
// SetRequestedModel sets the "requested_model" field.
func (u *UsageLogUpsertBulk) SetRequestedModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetRequestedModel(v)
})
}
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateRequestedModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateRequestedModel()
})
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (u *UsageLogUpsertBulk) ClearRequestedModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearRequestedModel()
})
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk { func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
......
...@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { ...@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
return _u return _u
} }
// SetRequestedModel sets the "requested_model" field.
func (_u *UsageLogUpdate) SetRequestedModel(v string) *UsageLogUpdate {
_u.mutation.SetRequestedModel(v)
return _u
}
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableRequestedModel(v *string) *UsageLogUpdate {
if v != nil {
_u.SetRequestedModel(*v)
}
return _u
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (_u *UsageLogUpdate) ClearRequestedModel() *UsageLogUpdate {
_u.mutation.ClearRequestedModel()
return _u
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate { func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
_u.mutation.SetUpstreamModel(v) _u.mutation.SetUpstreamModel(v)
...@@ -765,6 +785,11 @@ func (_u *UsageLogUpdate) check() error { ...@@ -765,6 +785,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _u.mutation.RequestedModel(); ok {
if err := usagelog.RequestedModelValidator(v); err != nil {
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok { if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil { if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
...@@ -820,6 +845,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -820,6 +845,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Model(); ok { if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
} }
if value, ok := _u.mutation.RequestedModel(); ok {
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
}
if _u.mutation.RequestedModelCleared() {
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
}
if value, ok := _u.mutation.UpstreamModel(); ok { if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
} }
...@@ -1208,6 +1239,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { ...@@ -1208,6 +1239,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
return _u return _u
} }
// SetRequestedModel sets the "requested_model" field.
func (_u *UsageLogUpdateOne) SetRequestedModel(v string) *UsageLogUpdateOne {
_u.mutation.SetRequestedModel(v)
return _u
}
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableRequestedModel(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetRequestedModel(*v)
}
return _u
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (_u *UsageLogUpdateOne) ClearRequestedModel() *UsageLogUpdateOne {
_u.mutation.ClearRequestedModel()
return _u
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
_u.mutation.SetUpstreamModel(v) _u.mutation.SetUpstreamModel(v)
...@@ -1884,6 +1935,11 @@ func (_u *UsageLogUpdateOne) check() error { ...@@ -1884,6 +1935,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _u.mutation.RequestedModel(); ok {
if err := usagelog.RequestedModelValidator(v); err != nil {
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok { if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil { if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
...@@ -1956,6 +2012,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err ...@@ -1956,6 +2012,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.Model(); ok { if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
} }
if value, ok := _u.mutation.RequestedModel(); ok {
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
}
if _u.mutation.RequestedModelCleared() {
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
}
if value, ok := _u.mutation.UpstreamModel(); ok { if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
} }
......
...@@ -94,6 +94,10 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL ...@@ -94,6 +94,10 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
...@@ -195,6 +199,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= ...@@ -195,6 +199,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
...@@ -230,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk ...@@ -230,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
...@@ -263,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= ...@@ -263,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
...@@ -314,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= ...@@ -314,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
......
...@@ -522,14 +522,17 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -522,14 +522,17 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。 // 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
requestType := l.EffectiveRequestType() requestType := l.EffectiveRequestType()
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
requestedModel := l.RequestedModel
if requestedModel == "" {
requestedModel = l.Model
}
return UsageLog{ return UsageLog{
ID: l.ID, ID: l.ID,
UserID: l.UserID, UserID: l.UserID,
APIKeyID: l.APIKeyID, APIKeyID: l.APIKeyID,
AccountID: l.AccountID, AccountID: l.AccountID,
RequestID: l.RequestID, RequestID: l.RequestID,
Model: l.Model, Model: requestedModel,
UpstreamModel: l.UpstreamModel,
ServiceTier: l.ServiceTier, ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort, ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint, InboundEndpoint: l.InboundEndpoint,
...@@ -586,6 +589,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { ...@@ -586,6 +589,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
} }
return &AdminUsageLog{ return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l), UsageLog: usageLogFromServiceUser(l),
UpstreamModel: l.UpstreamModel,
AccountRateMultiplier: l.AccountRateMultiplier, AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress, IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account), Account: AccountSummaryFromService(l.Account),
......
package dto package dto
import ( import (
"encoding/json"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -106,6 +107,47 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { ...@@ -106,6 +107,47 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
} }
func TestUsageLogFromService_UsesRequestedModelAndKeepsUpstreamAdminOnly(t *testing.T) {
t.Parallel()
upstreamModel := "claude-sonnet-4-20250514"
log := &service.UsageLog{
RequestID: "req_4",
Model: upstreamModel,
RequestedModel: "claude-sonnet-4",
UpstreamModel: &upstreamModel,
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
require.Equal(t, "claude-sonnet-4", userDTO.Model)
require.Equal(t, "claude-sonnet-4", adminDTO.Model)
userJSON, err := json.Marshal(userDTO)
require.NoError(t, err)
require.NotContains(t, string(userJSON), "upstream_model")
adminJSON, err := json.Marshal(adminDTO)
require.NoError(t, err)
require.Contains(t, string(adminJSON), `"upstream_model":"claude-sonnet-4-20250514"`)
}
func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *testing.T) {
t.Parallel()
log := &service.UsageLog{
RequestID: "req_legacy",
Model: "claude-3",
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
require.Equal(t, "claude-3", userDTO.Model)
require.Equal(t, "claude-3", adminDTO.Model)
}
func f64Ptr(value float64) *float64 { func f64Ptr(value float64) *float64 {
return &value return &value
} }
...@@ -334,9 +334,6 @@ type UsageLog struct { ...@@ -334,9 +334,6 @@ type UsageLog struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"` ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
...@@ -396,6 +393,10 @@ type UsageLog struct { ...@@ -396,6 +393,10 @@ type UsageLog struct {
type AdminUsageLog struct { type AdminUsageLog struct {
UsageLog UsageLog
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"` AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
......
...@@ -28,50 +28,64 @@ import ( ...@@ -28,50 +28,64 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
// 2. every INSERT/CTE VALUES column list in this file
// 3. execUsageLogInsertNoResult placeholder positions
// 4. scanUsageLog selected column order (via usageLogSelectColumns)
//
// When adding a usage_logs column, update all of those call sites together.
var usageLogInsertArgTypes = [...]string{ var usageLogInsertArgTypes = [...]string{
"bigint", "bigint", // user_id
"bigint", "bigint", // api_key_id
"bigint", "bigint", // account_id
"text", "text", // request_id
"text", "text", // model
"text", "text", // requested_model
"bigint", "text", // upstream_model
"bigint", "bigint", // group_id
"integer", "bigint", // subscription_id
"integer", "integer", // input_tokens
"integer", "integer", // output_tokens
"integer", "integer", // cache_creation_tokens
"integer", "integer", // cache_read_tokens
"integer", "integer", // cache_creation_5m_tokens
"numeric", "integer", // cache_creation_1h_tokens
"numeric", "numeric", // input_cost
"numeric", "numeric", // output_cost
"numeric", "numeric", // cache_creation_cost
"numeric", "numeric", // cache_read_cost
"numeric", "numeric", // total_cost
"numeric", "numeric", // actual_cost
"numeric", "numeric", // rate_multiplier
"smallint", "numeric", // account_rate_multiplier
"smallint", "smallint", // billing_type
"boolean", "smallint", // request_type
"boolean", "boolean", // stream
"integer", "boolean", // openai_ws_mode
"integer", "integer", // duration_ms
"text", "integer", // first_token_ms
"text", "text", // user_agent
"integer", "text", // ip_address
"text", "integer", // image_count
"text", "text", // image_size
"text", "text", // media_type
"text", "text", // service_tier
"text", "text", // reasoning_effort
"text", "text", // inbound_endpoint
"boolean", "text", // upstream_endpoint
"timestamptz", "boolean", // cache_ttl_overridden
"timestamptz", // created_at
} }
const rawUsageLogModelColumn = "model"
// rawUsageLogModelColumn preserves the exact stored usage_logs.model semantics for direct filters.
// Historical rows may contain upstream/billing model values, while newer rows store requested_model.
// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead.
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{ var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00", "hour": "YYYY-MM-DD HH24:00",
...@@ -88,6 +102,30 @@ func safeDateFormat(granularity string) string { ...@@ -88,6 +102,30 @@ func safeDateFormat(granularity string) string {
return "YYYY-MM-DD" return "YYYY-MM-DD"
} }
// appendRawUsageLogModelWhereCondition keeps direct model filters on the raw model column for backward
// compatibility with historical rows. Requested/upstream analytics must use
// resolveModelDimensionExpression instead.
func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model string) ([]string, []any) {
if strings.TrimSpace(model) == "" {
return conditions, args
}
conditions = append(conditions, fmt.Sprintf("%s = $%d", rawUsageLogModelColumn, len(args)+1))
args = append(args, model)
return conditions, args
}
// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward
// compatibility with historical rows. Requested/upstream analytics must use
// resolveModelDimensionExpression instead.
func appendRawUsageLogModelQueryFilter(query string, args []any, model string) (string, []any) {
if strings.TrimSpace(model) == "" {
return query, args
}
query += fmt.Sprintf(" AND %s = $%d", rawUsageLogModelColumn, len(args)+1)
args = append(args, model)
return query, args
}
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor sql sqlExecutor
...@@ -278,6 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -278,6 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -313,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -313,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $1, $2, $3, $4, $5, $6, $7,
$7, $8, $8, $9,
$9, $10, $11, $12, $10, $11, $12, $13,
$13, $14, $14, $15,
$15, $16, $17, $18, $19, $20, $16, $17, $18, $19, $20, $21,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -709,6 +748,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -709,6 +748,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -779,6 +819,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -779,6 +819,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -820,6 +861,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -820,6 +861,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -901,6 +943,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -901,6 +943,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -937,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -937,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*39) args := make([]any, 0, len(preparedList)*40)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -968,6 +1011,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -968,6 +1011,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -1009,6 +1053,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1009,6 +1053,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -1058,6 +1103,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1058,6 +1103,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
...@@ -1093,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1093,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $1, $2, $3, $4, $5, $6, $7,
$7, $8, $8, $9,
$9, $10, $11, $12, $10, $11, $12, $13,
$13, $14, $14, $15,
$15, $16, $17, $18, $19, $20, $16, $17, $18, $19, $20, $21,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1130,6 +1176,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1130,6 +1176,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint)
requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model)
}
upstreamModel := nullString(log.UpstreamModel) upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any var requestIDArg any
...@@ -1148,6 +1198,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1148,6 +1198,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID, log.AccountID,
requestIDArg, requestIDArg,
log.Model, log.Model,
nullString(&requestedModel),
upstreamModel, upstreamModel,
groupID, groupID,
subscriptionID, subscriptionID,
...@@ -1702,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco ...@@ -1702,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 // GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
// 性能优化:数据库层聚合计算,避免应用层循环统计 // 性能优化:数据库层聚合计算,避免应用层循环统计
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := ` query := fmt.Sprintf(`
SELECT SELECT
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
...@@ -1712,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN ...@@ -1712,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN
COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs FROM usage_logs
WHERE model = $1 AND created_at >= $2 AND created_at < $3 WHERE %s = $1 AND created_at >= $2 AND created_at < $3
` `, rawUsageLogModelColumn)
var stats usagestats.UsageStats var stats usagestats.UsageStats
if err := scanSingleRow( if err := scanSingleRow(
...@@ -1837,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco ...@@ -1837,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
} }
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn)
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err return logs, nil, err
} }
...@@ -2532,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat ...@@ -2532,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID) args = append(args, filters.GroupID)
} }
if filters.Model != "" { conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil { if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
...@@ -2768,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -2768,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID) args = append(args, groupID)
} }
if model != "" { query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
...@@ -3126,13 +3171,14 @@ func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayS ...@@ -3126,13 +3171,14 @@ func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayS
// resolveModelDimensionExpression maps model source type to a safe SQL expression. // resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string { func resolveModelDimensionExpression(modelType string) string {
requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
switch usagestats.NormalizeModelSource(modelType) { switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream: case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" return fmt.Sprintf("COALESCE(NULLIF(TRIM(upstream_model), ''), %s)", requestedExpr)
case usagestats.ModelSourceMapping: case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" return fmt.Sprintf("(%s || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), %s))", requestedExpr, requestedExpr)
default: default:
return "model" return requestedExpr
} }
} }
...@@ -3204,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -3204,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID) args = append(args, filters.GroupID)
} }
if filters.Model != "" { conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil { if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
...@@ -3336,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con ...@@ -3336,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID) args = append(args, groupID)
} }
if model != "" { query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
...@@ -3410,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context ...@@ -3410,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID) args = append(args, groupID)
} }
if model != "" { query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
...@@ -3888,6 +3925,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3888,6 +3925,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64 accountID int64
requestID sql.NullString requestID sql.NullString
model string model string
requestedModel sql.NullString
upstreamModel sql.NullString upstreamModel sql.NullString
groupID sql.NullInt64 groupID sql.NullInt64
subscriptionID sql.NullInt64 subscriptionID sql.NullInt64
...@@ -3931,6 +3969,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3931,6 +3969,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID, &accountID,
&requestID, &requestID,
&model, &model,
&requestedModel,
&upstreamModel, &upstreamModel,
&groupID, &groupID,
&subscriptionID, &subscriptionID,
...@@ -3975,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3975,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
APIKeyID: apiKeyID, APIKeyID: apiKeyID,
AccountID: accountID, AccountID: accountID,
Model: model, Model: model,
RequestedModel: coalesceTrimmedString(requestedModel, model),
InputTokens: inputTokens, InputTokens: inputTokens,
OutputTokens: outputTokens, OutputTokens: outputTokens,
CacheCreationTokens: cacheCreationTokens, CacheCreationTokens: cacheCreationTokens,
...@@ -4181,6 +4221,13 @@ func nullString(v *string) sql.NullString { ...@@ -4181,6 +4221,13 @@ func nullString(v *string) sql.NullString {
return sql.NullString{String: *v, Valid: true} return sql.NullString{String: *v, Valid: true}
} }
func coalesceTrimmedString(v sql.NullString, fallback string) string {
if v.Valid && strings.TrimSpace(v.String) != "" {
return v.String
}
return fallback
}
func setToSlice(set map[int64]struct{}) []int64 { func setToSlice(set map[int64]struct{}) []int64 {
out := make([]int64, 0, len(set)) out := make([]int64, 0, len(set))
for id := range set { for id := range set {
......
...@@ -34,11 +34,11 @@ func TestResolveModelDimensionExpression(t *testing.T) { ...@@ -34,11 +34,11 @@ func TestResolveModelDimensionExpression(t *testing.T) {
modelType string modelType string
want string want string
}{ }{
{usagestats.ModelSourceRequested, "model"}, {usagestats.ModelSourceRequested, "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"}, {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model))"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"}, {usagestats.ModelSourceMapping, "(COALESCE(NULLIF(TRIM(requested_model), ''), model) || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model)))"},
{"", "model"}, {"", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
{"invalid", "model"}, {"invalid", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
} }
for _, tc := range tests { for _, tc := range tests {
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
...@@ -21,20 +22,21 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -21,20 +22,21 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
log := &service.UsageLog{ log := &service.UsageLog{
UserID: 1, UserID: 1,
APIKeyID: 2, APIKeyID: 2,
AccountID: 3, AccountID: 3,
RequestID: "req-1", RequestID: "req-1",
Model: "gpt-5", Model: "gpt-5",
InputTokens: 10, RequestedModel: "gpt-5",
OutputTokens: 20, InputTokens: 10,
TotalCost: 1, OutputTokens: 20,
ActualCost: 1, TotalCost: 1,
BillingType: service.BillingTypeBalance, ActualCost: 1,
RequestType: service.RequestTypeWSV2, BillingType: service.BillingTypeBalance,
Stream: false, RequestType: service.RequestTypeWSV2,
OpenAIWSMode: false, Stream: false,
CreatedAt: createdAt, OpenAIWSMode: false,
CreatedAt: createdAt,
} }
mock.ExpectQuery("INSERT INTO usage_logs"). mock.ExpectQuery("INSERT INTO usage_logs").
...@@ -44,6 +46,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -44,6 +46,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
log.RequestedModel,
sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id sqlmock.AnyArg(), // subscription_id
...@@ -99,13 +102,14 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -99,13 +102,14 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC) createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
serviceTier := "priority" serviceTier := "priority"
log := &service.UsageLog{ log := &service.UsageLog{
UserID: 1, UserID: 1,
APIKeyID: 2, APIKeyID: 2,
AccountID: 3, AccountID: 3,
RequestID: "req-service-tier", RequestID: "req-service-tier",
Model: "gpt-5.4", Model: "gpt-5.4",
ServiceTier: &serviceTier, RequestedModel: "gpt-5.4",
CreatedAt: createdAt, ServiceTier: &serviceTier,
CreatedAt: createdAt,
} }
mock.ExpectQuery("INSERT INTO usage_logs"). mock.ExpectQuery("INSERT INTO usage_logs").
...@@ -115,6 +119,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -115,6 +119,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
log.RequestedModel,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
...@@ -158,6 +163,75 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -158,6 +163,75 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
require.NoError(t, mock.ExpectationsWereMet()) require.NoError(t, mock.ExpectationsWereMet())
} }
func TestBuildUsageLogBestEffortInsertQuery_IncludesRequestedModelColumn(t *testing.T) {
prepared := prepareUsageLogInsert(&service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-best-effort-query",
Model: "gpt-5",
RequestedModel: "gpt-5",
CreatedAt: time.Date(2025, 1, 3, 12, 0, 0, 0, time.UTC),
})
query, args := buildUsageLogBestEffortInsertQuery([]usageLogInsertPrepared{prepared})
require.Contains(t, query, "INSERT INTO usage_logs (")
require.Contains(t, query, "\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
require.Contains(t, query, "\n\t\t\trequest_id,\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
require.Len(t, args, len(prepared.args))
require.Equal(t, prepared.args[5], args[5])
}
func TestExecUsageLogInsertNoResult_PersistsRequestedModel(t *testing.T) {
db, mock := newSQLMock(t)
prepared := prepareUsageLogInsert(&service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-best-effort-exec",
Model: "gpt-5",
RequestedModel: "gpt-5",
CreatedAt: time.Date(2025, 1, 4, 12, 0, 0, 0, time.UTC),
})
mock.ExpectExec("INSERT INTO usage_logs").
WithArgs(anySliceToDriverValues(prepared.args)...).
WillReturnResult(sqlmock.NewResult(0, 1))
err := execUsageLogInsertNoResult(context.Background(), db, prepared)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) {
prepared := prepareUsageLogInsert(&service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-arg-count",
Model: "gpt-5",
RequestedModel: "gpt-5",
CreatedAt: time.Date(2025, 1, 5, 12, 0, 0, 0, time.UTC),
})
require.Len(t, prepared.args, len(usageLogInsertArgTypes))
}
func TestCoalesceTrimmedString(t *testing.T) {
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback"))
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback"))
require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback"))
}
func anySliceToDriverValues(values []any) []driver.Value {
out := make([]driver.Value, 0, len(values))
for _, value := range values {
out = append(out, value)
}
return out
}
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t) db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db} repo := &usageLogRepository{sql: db}
...@@ -354,7 +428,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -354,7 +428,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(20), // api_key_id int64(20), // api_key_id
int64(30), // account_id int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"}, sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model "gpt-5", // model
sql.NullString{Valid: true, String: "gpt-5"}, // requested_model
sql.NullString{}, // upstream_model sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id sql.NullInt64{}, // subscription_id
...@@ -407,6 +482,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -407,6 +482,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31), int64(31),
sql.NullString{Valid: true, String: "req-2"}, sql.NullString{Valid: true, String: "req-2"},
"gpt-5", "gpt-5",
sql.NullString{Valid: true, String: "gpt-5"},
sql.NullString{}, sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
...@@ -449,6 +525,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -449,6 +525,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32), int64(32),
sql.NullString{Valid: true, String: "req-3"}, sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4", "gpt-5.4",
sql.NullString{Valid: true, String: "gpt-5.4"},
sql.NullString{}, sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
......
...@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: billingModel, // 使用映射模型用于计费和日志 Model: originalModel,
UpstreamModel: billingModel,
Stream: claudeReq.Stream, Stream: claudeReq.Stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -2435,7 +2436,8 @@ handleSuccess: ...@@ -2435,7 +2436,8 @@ handleSuccess:
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: billingModel, Model: originalModel,
UpstreamModel: billingModel,
Stream: stream, Stream: stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { ...@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
result, err := svc.Forward(context.Background(), c, account, body, false) result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, "claude-sonnet-4-5", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
} }
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel // TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
...@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing ...@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, "gemini-2.5-flash", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
} }
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
...@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur ...@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, originalModel, result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
firstReq := string(upstream.requestBodies[0]) firstReq := string(upstream.requestBodies[0])
......
...@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID ...@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
} }
func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
mappedModel := "claude-sonnet-4-20250514"
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_models_split",
Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6},
Model: "claude-sonnet-4",
UpstreamModel: mappedModel,
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
......
...@@ -482,10 +482,12 @@ type ClaudeUsage struct { ...@@ -482,10 +482,12 @@ type ClaudeUsage struct {
// ForwardResult 转发结果 // ForwardResult 转发结果
type ForwardResult struct { type ForwardResult struct {
RequestID string RequestID string
Usage ClaudeUsage Usage ClaudeUsage
Model string Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping) // UpstreamModel is the actual upstream model after mapping.
// Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings.
UpstreamModel string
Stream bool Stream bool
Duration time.Duration Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
...@@ -7512,6 +7514,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7512,6 +7514,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
var cost *CostBreakdown var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" { if result.MediaType == "image" || result.MediaType == "video" {
...@@ -7527,7 +7530,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7527,7 +7530,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.MediaType == "image" { if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else { } else {
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
} }
} else if result.MediaType == "prompt" { } else if result.MediaType == "prompt" {
cost = &CostBreakdown{} cost = &CostBreakdown{}
...@@ -7541,7 +7544,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7541,7 +7544,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
Price4K: apiKey.Group.ImagePrice4K, Price4K: apiKey.Group.ImagePrice4K,
} }
} }
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else { } else {
// Token 计费 // Token 计费
tokens := UsageTokens{ tokens := UsageTokens{
...@@ -7553,7 +7556,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7553,7 +7556,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0} cost = &CostBreakdown{ActualCost: 0}
...@@ -7585,6 +7588,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7585,6 +7588,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
...@@ -7715,6 +7719,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7715,6 +7719,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} }
var cost *CostBreakdown var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.ImageCount > 0 { if result.ImageCount > 0 {
...@@ -7727,7 +7732,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7727,7 +7732,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
Price4K: apiKey.Group.ImagePrice4K, Price4K: apiKey.Group.ImagePrice4K,
} }
} }
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else { } else {
// Token 计费(使用长上下文计费方法) // Token 计费(使用长上下文计费方法)
tokens := UsageTokens{ tokens := UsageTokens{
...@@ -7739,7 +7744,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7739,7 +7744,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0} cost = &CostBreakdown{ActualCost: 0}
...@@ -7767,6 +7772,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7767,6 +7772,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment