Unverified Commit 6bccb8a8 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
......@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
......@@ -55,6 +56,10 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
......@@ -64,7 +69,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id`
var id int64
......@@ -97,6 +102,10 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON),
opsNullInt64(input.AuthLatencyMs),
opsNullInt64(input.RoutingLatencyMs),
opsNullInt64(input.UpstreamLatencyMs),
opsNullInt64(input.ResponseLatencyMs),
opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated,
......@@ -930,6 +939,243 @@ WHERE id = $1`
return err
}
func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if len(inputs) == 0 {
return 0, nil
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
stmt, err := tx.PrepareContext(ctx, pq.CopyIn(
"ops_system_logs",
"created_at",
"level",
"component",
"message",
"request_id",
"client_request_id",
"user_id",
"account_id",
"platform",
"model",
"extra",
))
if err != nil {
_ = tx.Rollback()
return 0, err
}
var inserted int64
for _, input := range inputs {
if input == nil {
continue
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
component := strings.TrimSpace(input.Component)
level := strings.ToLower(strings.TrimSpace(input.Level))
message := strings.TrimSpace(input.Message)
if level == "" || message == "" {
continue
}
if component == "" {
component = "app"
}
extra := strings.TrimSpace(input.ExtraJSON)
if extra == "" {
extra = "{}"
}
if _, err := stmt.ExecContext(
ctx,
createdAt.UTC(),
level,
component,
message,
opsNullString(input.RequestID),
opsNullString(input.ClientRequestID),
opsNullInt64(input.UserID),
opsNullInt64(input.AccountID),
opsNullString(input.Platform),
opsNullString(input.Model),
extra,
); err != nil {
_ = stmt.Close()
_ = tx.Rollback()
return inserted, err
}
inserted++
}
if _, err := stmt.ExecContext(ctx); err != nil {
_ = stmt.Close()
_ = tx.Rollback()
return inserted, err
}
if err := stmt.Close(); err != nil {
_ = tx.Rollback()
return inserted, err
}
if err := tx.Commit(); err != nil {
return inserted, err
}
return inserted, nil
}
func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsSystemLogFilter{}
}
page := filter.Page
if page <= 0 {
page = 1
}
pageSize := filter.PageSize
if pageSize <= 0 {
pageSize = 50
}
if pageSize > 200 {
pageSize = 200
}
where, args, _ := buildOpsSystemLogsWhere(filter)
countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where
var total int
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
return nil, err
}
offset := (page - 1) * pageSize
argsWithLimit := append(args, pageSize, offset)
query := `
SELECT
l.id,
l.created_at,
l.level,
COALESCE(l.component, ''),
COALESCE(l.message, ''),
COALESCE(l.request_id, ''),
COALESCE(l.client_request_id, ''),
l.user_id,
l.account_id,
COALESCE(l.platform, ''),
COALESCE(l.model, ''),
COALESCE(l.extra::text, '{}')
FROM ops_system_logs l
` + where + `
ORDER BY l.created_at DESC, l.id DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
rows, err := r.db.QueryContext(ctx, query, argsWithLimit...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
logs := make([]*service.OpsSystemLog, 0, pageSize)
for rows.Next() {
item := &service.OpsSystemLog{}
var userID sql.NullInt64
var accountID sql.NullInt64
var extraRaw string
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&item.Level,
&item.Component,
&item.Message,
&item.RequestID,
&item.ClientRequestID,
&userID,
&accountID,
&item.Platform,
&item.Model,
&extraRaw,
); err != nil {
return nil, err
}
if userID.Valid {
v := userID.Int64
item.UserID = &v
}
if accountID.Valid {
v := accountID.Int64
item.AccountID = &v
}
extraRaw = strings.TrimSpace(extraRaw)
if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" {
extra := make(map[string]any)
if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil {
item.Extra = extra
}
}
logs = append(logs, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return &service.OpsSystemLogList{
Logs: logs,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsSystemLogCleanupFilter{}
}
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
if !hasConstraint {
return 0, fmt.Errorf("cleanup requires at least one filter condition")
}
query := "DELETE FROM ops_system_logs l " + where
res, err := r.db.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if input == nil {
return fmt.Errorf("nil input")
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
_, err := r.db.ExecContext(ctx, `
INSERT INTO ops_system_log_cleanup_audits (
created_at,
operator_id,
conditions,
deleted_rows
) VALUES ($1,$2,$3,$4)
`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows)
return err
}
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses := make([]string, 0, 12)
args := make([]any, 0, 12)
......@@ -948,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
......@@ -962,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p)
clauses = append(clauses, "platform = $"+itoa(len(args)))
clauses = append(clauses, "e.platform = $"+itoa(len(args)))
}
if filter.GroupID != nil && *filter.GroupID > 0 {
args = append(args, *filter.GroupID)
clauses = append(clauses, "group_id = $"+itoa(len(args)))
clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
}
if filter.AccountID != nil && *filter.AccountID > 0 {
args = append(args, *filter.AccountID)
clauses = append(clauses, "account_id = $"+itoa(len(args)))
clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
}
if phase := phaseFilter; phase != "" {
args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
}
if filter != nil {
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
args = append(args, owner)
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
}
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
args = append(args, source)
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
}
}
if resolvedFilter != nil {
args = append(args, *resolvedFilter)
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
}
// View filter: errors vs excluded vs all.
......@@ -1000,51 +1246,140 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
case "excluded":
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
}
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
} else if filter.StatusCodesOther {
// "Other" means: status codes not in the common list.
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
args = append(args, pq.Array(known))
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
}
// Exact correlation keys (preferred for request↔upstream linkage).
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
args = append(args, rid)
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
}
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
args = append(args, crid)
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
}
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
like := "%" + userQuery + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "u.email ILIKE $"+n)
clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) {
clauses := make([]string, 0, 10)
args := make([]any, 0, 10)
clauses = append(clauses, "1=1")
hasConstraint := false
if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, filter.StartTime.UTC())
clauses = append(clauses, "l.created_at >= $"+itoa(len(args)))
hasConstraint = true
}
if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, filter.EndTime.UTC())
clauses = append(clauses, "l.created_at < $"+itoa(len(args)))
hasConstraint = true
}
if filter != nil {
if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" {
args = append(args, v)
clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Component); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.RequestID); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.ClientRequestID); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args)))
hasConstraint = true
}
if filter.UserID != nil && *filter.UserID > 0 {
args = append(args, *filter.UserID)
clauses = append(clauses, "l.user_id = $"+itoa(len(args)))
hasConstraint = true
}
if filter.AccountID != nil && *filter.AccountID > 0 {
args = append(args, *filter.AccountID)
clauses = append(clauses, "l.account_id = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Platform); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Model); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Query); v != "" {
like := "%" + v + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")")
hasConstraint = true
}
}
return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint
}
func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) {
if filter == nil {
filter = &service.OpsSystemLogCleanupFilter{}
}
listFilter := &service.OpsSystemLogFilter{
StartTime: filter.StartTime,
EndTime: filter.EndTime,
Level: filter.Level,
Component: filter.Component,
RequestID: filter.RequestID,
ClientRequestID: filter.ClientRequestID,
UserID: filter.UserID,
AccountID: filter.AccountID,
Platform: filter.Platform,
Model: filter.Model,
Query: filter.Query,
}
return buildOpsSystemLogsWhere(listFilter)
}
// Helpers for nullable args
func opsNullString(v any) any {
switch s := v.(type) {
......
package repository
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
filter := &service.OpsErrorLogFilter{
Query: "ACCESS_DENIED",
}
where, args := buildOpsErrorLogsWhere(filter)
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 1 {
t.Fatalf("args len = %d, want 1", len(args))
}
if !strings.Contains(where, "e.request_id ILIKE $") {
t.Fatalf("where should include qualified request_id condition: %s", where)
}
if !strings.Contains(where, "e.client_request_id ILIKE $") {
t.Fatalf("where should include qualified client_request_id condition: %s", where)
}
if !strings.Contains(where, "e.error_message ILIKE $") {
t.Fatalf("where should include qualified error_message condition: %s", where)
}
}
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
filter := &service.OpsErrorLogFilter{
UserQuery: "admin@",
}
where, args := buildOpsErrorLogsWhere(filter)
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 1 {
t.Fatalf("args len = %d, want 1", len(args))
}
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
t.Fatalf("where should include EXISTS user email condition: %s", where)
}
}
package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetOpenAITokenStats(ctx context.Context, filter *service.OpsOpenAITokenStatsFilter) (*service.OpsOpenAITokenStatsResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
// 允许 start_time == end_time(结果为空),与 service 层校验口径保持一致。
if filter.StartTime.After(filter.EndTime) {
return nil, fmt.Errorf("start_time must be <= end_time")
}
dashboardFilter := &service.OpsDashboardFilter{
StartTime: filter.StartTime.UTC(),
EndTime: filter.EndTime.UTC(),
Platform: strings.TrimSpace(strings.ToLower(filter.Platform)),
GroupID: filter.GroupID,
}
join, where, baseArgs, next := buildUsageWhere(dashboardFilter, dashboardFilter.StartTime, dashboardFilter.EndTime, 1)
where += " AND ul.model LIKE 'gpt%'"
baseCTE := `
WITH stats AS (
SELECT
ul.model AS model,
COUNT(*)::bigint AS request_count,
ROUND(
AVG(
CASE
WHEN ul.duration_ms > 0 AND ul.output_tokens > 0
THEN ul.output_tokens * 1000.0 / ul.duration_ms
END
)::numeric,
2
)::float8 AS avg_tokens_per_sec,
ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms,
COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms,
COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token
FROM usage_logs ul
` + join + `
` + where + `
GROUP BY ul.model
)
`
countSQL := baseCTE + `SELECT COUNT(*) FROM stats`
var total int64
if err := r.db.QueryRowContext(ctx, countSQL, baseArgs...).Scan(&total); err != nil {
return nil, err
}
querySQL := baseCTE + `
SELECT
model,
request_count,
avg_tokens_per_sec,
avg_first_token_ms,
total_output_tokens,
avg_duration_ms,
requests_with_first_token
FROM stats
ORDER BY request_count DESC, model ASC`
args := make([]any, 0, len(baseArgs)+2)
args = append(args, baseArgs...)
if filter.IsTopNMode() {
querySQL += fmt.Sprintf("\nLIMIT $%d", next)
args = append(args, filter.TopN)
} else {
offset := (filter.Page - 1) * filter.PageSize
querySQL += fmt.Sprintf("\nLIMIT $%d OFFSET $%d", next, next+1)
args = append(args, filter.PageSize, offset)
}
rows, err := r.db.QueryContext(ctx, querySQL, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsOpenAITokenStatsItem, 0, 32)
for rows.Next() {
item := &service.OpsOpenAITokenStatsItem{}
var avgTPS sql.NullFloat64
var avgFirstToken sql.NullFloat64
if err := rows.Scan(
&item.Model,
&item.RequestCount,
&avgTPS,
&avgFirstToken,
&item.TotalOutputTokens,
&item.AvgDurationMs,
&item.RequestsWithFirstToken,
); err != nil {
return nil, err
}
if avgTPS.Valid {
v := avgTPS.Float64
item.AvgTokensPerSec = &v
}
if avgFirstToken.Valid {
v := avgFirstToken.Float64
item.AvgFirstTokenMs = &v
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
resp := &service.OpsOpenAITokenStatsResponse{
TimeRange: strings.TrimSpace(filter.TimeRange),
StartTime: dashboardFilter.StartTime,
EndTime: dashboardFilter.EndTime,
Platform: dashboardFilter.Platform,
GroupID: dashboardFilter.GroupID,
Items: items,
Total: total,
}
if filter.IsTopNMode() {
topN := filter.TopN
resp.TopN = &topN
} else {
resp.Page = filter.Page
resp.PageSize = filter.PageSize
}
return resp, nil
}
package repository
import (
"context"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestOpsRepositoryGetOpenAITokenStats_PaginationMode(t *testing.T) {
db, mock := newSQLMock(t)
repo := &opsRepository{db: db}
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
groupID := int64(9)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: "1d",
StartTime: start,
EndTime: end,
Platform: " OpenAI ",
GroupID: &groupID,
Page: 2,
PageSize: 10,
}
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
WithArgs(start, end, groupID, "openai").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(3)))
rows := sqlmock.NewRows([]string{
"model",
"request_count",
"avg_tokens_per_sec",
"avg_first_token_ms",
"total_output_tokens",
"avg_duration_ms",
"requests_with_first_token",
}).
AddRow("gpt-4o-mini", int64(20), 21.56, 120.34, int64(3000), int64(850), int64(18)).
AddRow("gpt-4.1", int64(20), 10.2, 240.0, int64(2500), int64(900), int64(20))
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`).
WithArgs(start, end, groupID, "openai", 10, 10).
WillReturnRows(rows)
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, int64(3), resp.Total)
require.Equal(t, 2, resp.Page)
require.Equal(t, 10, resp.PageSize)
require.Nil(t, resp.TopN)
require.Equal(t, "openai", resp.Platform)
require.NotNil(t, resp.GroupID)
require.Equal(t, groupID, *resp.GroupID)
require.Len(t, resp.Items, 2)
require.Equal(t, "gpt-4o-mini", resp.Items[0].Model)
require.NotNil(t, resp.Items[0].AvgTokensPerSec)
require.InDelta(t, 21.56, *resp.Items[0].AvgTokensPerSec, 0.0001)
require.NotNil(t, resp.Items[0].AvgFirstTokenMs)
require.InDelta(t, 120.34, *resp.Items[0].AvgFirstTokenMs, 0.0001)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestOpsRepositoryGetOpenAITokenStats_TopNMode(t *testing.T) {
db, mock := newSQLMock(t)
repo := &opsRepository{db: db}
start := time.Date(2026, 1, 1, 10, 0, 0, 0, time.UTC)
end := start.Add(time.Hour)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: "1h",
StartTime: start,
EndTime: end,
TopN: 5,
}
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
WithArgs(start, end).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
rows := sqlmock.NewRows([]string{
"model",
"request_count",
"avg_tokens_per_sec",
"avg_first_token_ms",
"total_output_tokens",
"avg_duration_ms",
"requests_with_first_token",
}).
AddRow("gpt-4o", int64(5), nil, nil, int64(0), int64(0), int64(0))
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`).
WithArgs(start, end, 5).
WillReturnRows(rows)
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
require.NoError(t, err)
require.NotNil(t, resp)
require.NotNil(t, resp.TopN)
require.Equal(t, 5, *resp.TopN)
require.Equal(t, 0, resp.Page)
require.Equal(t, 0, resp.PageSize)
require.Len(t, resp.Items, 1)
require.Nil(t, resp.Items[0].AvgTokensPerSec)
require.Nil(t, resp.Items[0].AvgFirstTokenMs)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestOpsRepositoryGetOpenAITokenStats_EmptyResult(t *testing.T) {
db, mock := newSQLMock(t)
repo := &opsRepository{db: db}
start := time.Date(2026, 1, 2, 0, 0, 0, 0, time.UTC)
end := start.Add(30 * time.Minute)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: "30m",
StartTime: start,
EndTime: end,
Page: 1,
PageSize: 20,
}
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
WithArgs(start, end).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`).
WithArgs(start, end, 20, 0).
WillReturnRows(sqlmock.NewRows([]string{
"model",
"request_count",
"avg_tokens_per_sec",
"avg_first_token_ms",
"total_output_tokens",
"avg_duration_ms",
"requests_with_first_token",
}))
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, int64(0), resp.Total)
require.Len(t, resp.Items, 0)
require.Equal(t, 1, resp.Page)
require.Equal(t, 20, resp.PageSize)
require.NoError(t, mock.ExpectationsWereMet())
}
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) {
start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC)
userID := int64(12)
accountID := int64(34)
filter := &service.OpsSystemLogFilter{
StartTime: &start,
EndTime: &end,
Level: "warn",
Component: "http.access",
RequestID: "req-1",
ClientRequestID: "creq-1",
UserID: &userID,
AccountID: &accountID,
Platform: "openai",
Model: "gpt-5",
Query: "timeout",
}
where, args, hasConstraint := buildOpsSystemLogsWhere(filter)
if !hasConstraint {
t.Fatalf("expected hasConstraint=true")
}
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 11 {
t.Fatalf("args len = %d, want 11", len(args))
}
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
t.Fatalf("where should include client_request_id condition: %s", where)
}
if !contains(where, "l.user_id = $") {
t.Fatalf("where should include user_id condition: %s", where)
}
}
func TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) {
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{})
if hasConstraint {
t.Fatalf("expected hasConstraint=false")
}
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 0 {
t.Fatalf("args len = %d, want 0", len(args))
}
}
func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) {
userID := int64(9)
filter := &service.OpsSystemLogCleanupFilter{
ClientRequestID: "creq-9",
UserID: &userID,
}
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
if !hasConstraint {
t.Fatalf("expected hasConstraint=true")
}
if len(args) != 2 {
t.Fatalf("args len = %d, want 2", len(args))
}
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
t.Fatalf("where should include client_request_id condition: %s", where)
}
if !contains(where, "l.user_id = $") {
t.Fatalf("where should include user_id condition: %s", where)
}
}
func contains(s string, sub string) bool {
return strings.Contains(s, sub)
}
......@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
......@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
......
......@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false
allowPrivate := false
validateResolvedIP := true
maxResponseBytes := defaultProxyProbeResponseMaxBytes
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
}
}
if insecure {
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
......@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
maxResponseBytes: maxResponseBytes,
}
}
const (
defaultProxyProbeTimeout = 30 * time.Second
defaultProxyProbeTimeout = 30 * time.Second
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
)
// probeURLs 按优先级排列的探测 URL 列表
......@@ -52,6 +58,7 @@ type proxyProbeService struct {
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
maxResponseBytes int64
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
......@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
maxResponseBytes := s.maxResponseBytes
if maxResponseBytes <= 0 {
maxResponseBytes = defaultProxyProbeResponseMaxBytes
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if int64(len(body)) > maxResponseBytes {
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
}
switch parser {
case "ip-api":
......
package repository
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const (
securitySecretKeyJWT = "jwt_secret"
securitySecretReadRetryMax = 5
securitySecretReadRetryWait = 10 * time.Millisecond
)
var readRandomBytes = rand.Read
func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error {
if client == nil {
return fmt.Errorf("nil ent client")
}
if cfg == nil {
return fmt.Errorf("nil config")
}
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
if cfg.JWT.Secret != "" {
storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret)
if err != nil {
return fmt.Errorf("persist jwt secret: %w", err)
}
if storedSecret != cfg.JWT.Secret {
log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.")
}
cfg.JWT.Secret = storedSecret
return nil
}
secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32)
if err != nil {
return fmt.Errorf("ensure jwt secret: %w", err)
}
cfg.JWT.Secret = secret
if created {
log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.")
}
return nil
}
func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) {
existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
if err == nil {
value := strings.TrimSpace(existing.Value)
if len([]byte(value)) < 32 {
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
}
return value, false, nil
}
if !ent.IsNotFound(err) {
return "", false, err
}
generated, err := generateHexSecret(byteLength)
if err != nil {
return "", false, err
}
if err := client.SecuritySecret.Create().
SetKey(key).
SetValue(generated).
OnConflictColumns(securitysecret.FieldKey).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return "", false, err
}
}
stored, err := querySecuritySecretWithRetry(ctx, client, key)
if err != nil {
return "", false, err
}
value := strings.TrimSpace(stored.Value)
if len([]byte(value)) < 32 {
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
}
return value, value == generated, nil
}
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) {
value = strings.TrimSpace(value)
if len([]byte(value)) < 32 {
return "", fmt.Errorf("secret %q must be at least 32 bytes", key)
}
if err := client.SecuritySecret.Create().
SetKey(key).
SetValue(value).
OnConflictColumns(securitysecret.FieldKey).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return "", err
}
}
stored, err := querySecuritySecretWithRetry(ctx, client, key)
if err != nil {
return "", err
}
storedValue := strings.TrimSpace(stored.Value)
if len([]byte(storedValue)) < 32 {
return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key)
}
return storedValue, nil
}
func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) {
var lastErr error
for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ {
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
if err == nil {
return stored, nil
}
if !isSecretNotFoundError(err) {
return nil, err
}
lastErr = err
if attempt == securitySecretReadRetryMax {
break
}
timer := time.NewTimer(securitySecretReadRetryWait)
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-timer.C:
}
}
return nil, lastErr
}
func isSecretNotFoundError(err error) bool {
if err == nil {
return false
}
return ent.IsNotFound(err) || isSQLNoRowsError(err)
}
func isSQLNoRowsError(err error) bool {
if err == nil {
return false
}
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
}
func generateHexSecret(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 32
}
buf := make([]byte, byteLength)
if _, err := readRandomBytes(buf); err != nil {
return "", fmt.Errorf("generate random secret: %w", err)
}
return hex.EncodeToString(buf), nil
}
package repository
import (
"context"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"strings"
"sync"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newSecuritySecretTestClient(t *testing.T) *dbent.Client {
t.Helper()
name := strings.ReplaceAll(t.Name(), "/", "_")
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", name)
db, err := sql.Open("sqlite", dsn)
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
func TestEnsureBootstrapSecretsNilInputs(t *testing.T) {
err := ensureBootstrapSecrets(context.Background(), nil, &config.Config{})
require.Error(t, err)
require.Contains(t, err.Error(), "nil ent client")
client := newSecuritySecretTestClient(t)
err = ensureBootstrapSecrets(context.Background(), client, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "nil config")
}
func TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
cfg := &config.Config{}
err := ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
require.NotEmpty(t, cfg.JWT.Secret)
require.GreaterOrEqual(t, len([]byte(cfg.JWT.Secret)), 32)
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
require.NoError(t, err)
require.Equal(t, cfg.JWT.Secret, stored.Value)
}
func TestEnsureBootstrapSecretsLoadExistingJWTSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("existing-jwt-secret-32bytes-long!!!!").Save(context.Background())
require.NoError(t, err)
cfg := &config.Config{}
err = ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
}
func TestEnsureBootstrapSecretsRejectInvalidStoredSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("too-short").Save(context.Background())
require.NoError(t, err)
cfg := &config.Config{}
err = ensureBootstrapSecrets(context.Background(), client, cfg)
require.Error(t, err)
require.Contains(t, err.Error(), "at least 32 bytes")
}
func TestEnsureBootstrapSecretsPersistConfiguredJWTSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
cfg := &config.Config{
JWT: config.JWTConfig{Secret: "configured-jwt-secret-32bytes-long!!"},
}
err := ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
require.NoError(t, err)
require.Equal(t, "configured-jwt-secret-32bytes-long!!", stored.Value)
}
func TestEnsureBootstrapSecretsConfiguredSecretTooShort(t *testing.T) {
client := newSecuritySecretTestClient(t)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "short"}}
err := ensureBootstrapSecrets(context.Background(), client, cfg)
require.Error(t, err)
require.Contains(t, err.Error(), "at least 32 bytes")
}
func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().
SetKey(securitySecretKeyJWT).
SetValue("existing-jwt-secret-32bytes-long!!!!").
Save(context.Background())
require.NoError(t, err)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "another-configured-jwt-secret-32!!!!"}}
err = ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
require.NoError(t, err)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
}
func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().
SetKey("trimmed_key").
SetValue(" existing-trimmed-secret-32bytes-long!! ").
Save(context.Background())
require.NoError(t, err)
value, created, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "trimmed_key", 32)
require.NoError(t, err)
require.False(t, created)
require.Equal(t, "existing-trimmed-secret-32bytes-long!!", value)
}
func TestGetOrCreateGeneratedSecuritySecretQueryError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "closed_client_key", 32)
require.Error(t, err)
}
func TestGetOrCreateGeneratedSecuritySecretCreateValidationError(t *testing.T) {
client := newSecuritySecretTestClient(t)
tooLongKey := strings.Repeat("k", 101)
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, tooLongKey, 32)
require.Error(t, err)
}
func TestGetOrCreateGeneratedSecuritySecretConcurrentCreation(t *testing.T) {
client := newSecuritySecretTestClient(t)
const goroutines = 8
key := "concurrent_bootstrap_key"
values := make([]string, goroutines)
createdFlags := make([]bool, goroutines)
errs := make([]error, goroutines)
var wg sync.WaitGroup
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
values[idx], createdFlags[idx], errs[idx] = getOrCreateGeneratedSecuritySecret(context.Background(), client, key, 32)
}(i)
}
wg.Wait()
for i := range errs {
require.NoError(t, errs[i])
require.NotEmpty(t, values[i])
}
for i := 1; i < len(values); i++ {
require.Equal(t, values[0], values[i])
}
createdCount := 0
for _, created := range createdFlags {
if created {
createdCount++
}
}
require.GreaterOrEqual(t, createdCount, 1)
require.LessOrEqual(t, createdCount, 1)
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Count(context.Background())
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) {
client := newSecuritySecretTestClient(t)
originalRead := readRandomBytes
readRandomBytes = func([]byte) (int, error) {
return 0, errors.New("boom")
}
t.Cleanup(func() {
readRandomBytes = originalRead
})
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "gen_error_key", 32)
require.Error(t, err)
require.Contains(t, err.Error(), "boom")
}
func TestCreateSecuritySecretIfAbsent(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short")
require.Error(t, err)
require.Contains(t, err.Error(), "at least 32 bytes")
stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long")
require.NoError(t, err)
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes")
require.NoError(t, err)
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background())
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := createSecuritySecretIfAbsent(
context.Background(),
client,
strings.Repeat("k", 101),
"valid-jwt-secret-value-32bytes-long",
)
require.Error(t, err)
}
func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long")
require.Error(t, err)
}
func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) {
client := newSecuritySecretTestClient(t)
created, err := client.SecuritySecret.Create().
SetKey("retry_success_key").
SetValue("retry-success-jwt-secret-value-32!!").
Save(context.Background())
require.NoError(t, err)
got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key")
require.NoError(t, err)
require.Equal(t, created.ID, got.ID)
require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value)
}
func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key")
require.Error(t, err)
require.True(t, isSecretNotFoundError(err))
}
func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) {
client := newSecuritySecretTestClient(t)
ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2)
defer cancel()
_, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key")
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key")
require.Error(t, err)
require.False(t, isSecretNotFoundError(err))
}
func TestSecretNotFoundHelpers(t *testing.T) {
require.False(t, isSecretNotFoundError(nil))
require.False(t, isSQLNoRowsError(nil))
require.True(t, isSQLNoRowsError(sql.ErrNoRows))
require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows)))
require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set")))
require.True(t, isSecretNotFoundError(sql.ErrNoRows))
require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set")))
require.False(t, isSecretNotFoundError(errors.New("some other error")))
}
func TestGenerateHexSecretReadError(t *testing.T) {
originalRead := readRandomBytes
readRandomBytes = func([]byte) (int, error) {
return 0, errors.New("read random failed")
}
t.Cleanup(func() {
readRandomBytes = originalRead
})
_, err := generateHexSecret(32)
require.Error(t, err)
require.Contains(t, err.Error(), "read random failed")
}
func TestGenerateHexSecretLengths(t *testing.T) {
v1, err := generateHexSecret(0)
require.NoError(t, err)
require.Len(t, v1, 64)
_, err = hex.DecodeString(v1)
require.NoError(t, err)
v2, err := generateHexSecret(16)
require.NoError(t, err)
require.Len(t, v2, 32)
_, err = hex.DecodeString(v2)
require.NoError(t, err)
require.NotEqual(t, v1, v2)
}
package repository
import (
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
type soraAccountRepository struct {
sql *sql.DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
return &soraAccountRepository{sql: sqlDB}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
accessToken, accessOK := updates["access_token"].(string)
refreshToken, refreshOK := updates["refresh_token"].(string)
sessionToken, sessionOK := updates["session_token"].(string)
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
if !sessionOK {
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
}
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`, accountID, sessionToken)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
}
return nil
}
_, err := r.sql.ExecContext(ctx, `
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`, accountID, accessToken, refreshToken, sessionToken)
return err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, nil // 记录不存在
}
var sa service.SoraAccount
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
return nil, err
}
return &sa, nil
}
// Delete 删除 Sora 账号扩展信息
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
_, err := r.sql.ExecContext(ctx, `
DELETE FROM sora_accounts WHERE account_id = $1
`, accountID)
return err
}
......@@ -22,7 +22,23 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
"day": "YYYY-MM-DD",
"week": "IYYY-IW",
"month": "YYYY-MM",
}
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
func safeDateFormat(granularity string) string {
if f, ok := dateFormatWhitelist[granularity]; ok {
return f
}
return "YYYY-MM-DD"
}
type usageLogRepository struct {
client *dbent.Client
......@@ -111,23 +127,24 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
ip_address,
image_count,
image_size,
media_type,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
groupID := nullInt64(log.GroupID)
subscriptionID := nullInt64(log.SubscriptionID)
......@@ -136,6 +153,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
reasoningEffort := nullString(log.ReasoningEffort)
var requestIDArg any
......@@ -173,6 +191,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress,
log.ImageCount,
imageSize,
mediaType,
reasoningEffort,
log.CacheTTLOverridden,
createdAt,
......@@ -566,7 +585,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
return logs, nil, err
}
......@@ -812,19 +831,19 @@ func resolveUsageStatsTimezone() string {
}
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err
}
......@@ -896,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
return stats, nil
}
// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。
// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。
func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) {
result := make(map[int64]*usagestats.AccountStats, len(accountIDs))
if len(accountIDs) == 0 {
return result, nil
}
query := `
SELECT
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = ANY($1) AND created_at >= $2
GROUP BY account_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var accountID int64
stats := &usagestats.AccountStats{}
if err := rows.Scan(
&accountID,
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
result[accountID] = stats
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, accountID := range accountIDs {
if _, ok := result[accountID]; !ok {
result[accountID] = &usagestats.AccountStats{}
}
}
return result, nil
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint = usagestats.TrendDataPoint
......@@ -910,10 +982,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
WITH top_keys AS (
......@@ -968,10 +1037,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
// GetUserUsageTrend returns usage trend data grouped by user and date
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
WITH top_users AS (
......@@ -1230,10 +1296,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
SELECT
......@@ -1371,13 +1434,22 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats)
if len(userIDs) == 0 {
return result, nil
}
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range userIDs {
result[id] = &BatchUserUsageStats{UserID: id}
}
......@@ -1385,10 +1457,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
query := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE user_id = ANY($1)
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY user_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
if err != nil {
return nil, err
}
......@@ -1445,13 +1517,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range apiKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
......@@ -1459,10 +1540,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
query := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE api_key_id = ANY($1)
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY api_key_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
if err != nil {
return nil, err
}
......@@ -1518,10 +1599,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
SELECT
......@@ -2196,6 +2274,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
mediaType sql.NullString
reasoningEffort sql.NullString
cacheTTLOverridden bool
createdAt time.Time
......@@ -2232,6 +2311,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress,
&imageCount,
&imageSize,
&mediaType,
&reasoningEffort,
&cacheTTLOverridden,
&createdAt,
......@@ -2294,6 +2374,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
if mediaType.Valid {
log.MediaType = &mediaType.String
}
if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String
}
......
......@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchUserUsageStats")
s.Require().Len(stats, 2)
s.Require().NotNil(stats[user1.ID])
......@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
}
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
......@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
......
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSafeDateFormat(t *testing.T) {
tests := []struct {
name string
granularity string
expected string
}{
// 合法值
{"hour", "hour", "YYYY-MM-DD HH24:00"},
{"day", "day", "YYYY-MM-DD"},
{"week", "week", "IYYY-IW"},
{"month", "month", "YYYY-MM"},
// 非法值回退到默认
{"空字符串", "", "YYYY-MM-DD"},
{"未知粒度 year", "year", "YYYY-MM-DD"},
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
// 恶意字符串
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
{"带引号", "day'", "YYYY-MM-DD"},
{"带括号", "day)", "YYYY-MM-DD"},
{"Unicode", "日", "YYYY-MM-DD"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := safeDateFormat(tc.granularity)
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
})
}
}
......@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
return NewGitHubReleaseClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
......@@ -53,12 +53,14 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,
NewSettingRepository,
......
......@@ -83,6 +83,7 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
......@@ -122,6 +123,7 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
......@@ -184,6 +186,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
......@@ -401,6 +407,7 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"media_type": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
......@@ -593,13 +600,13 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo, nil)
userService := service.NewUserService(userRepo, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
......@@ -608,7 +615,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
......@@ -925,6 +932,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
return errors.New("not implemented")
}
......@@ -1462,6 +1473,20 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
key, ok := r.byID[id]
if !ok {
return service.ErrAPIKeyNotFound
}
ts := usedAt
key.LastUsedAt = &ts
key.UpdatedAt = usedAt
clone := *key
r.byID[id] = &clone
r.byKey[clone.Key] = &clone
return nil
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}
......@@ -1607,11 +1632,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
......
......@@ -51,6 +51,9 @@ func ProvideRouter(
if err := r.SetTrustedProxies(nil); err != nil {
log.Printf("Failed to disable trusted proxies: %v", err)
}
if cfg.Server.Mode == "release" {
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
......
......@@ -58,8 +58,13 @@ func adminAuth(
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userService) {
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
token := strings.TrimSpace(parts[1])
if token == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
return
}
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()
......@@ -176,6 +181,12 @@ func validateJWTForAdmin(
return false
}
// 校验 TokenVersion,确保管理员改密后旧 token 失效
if claims.TokenVersion != user.TokenVersion {
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
return false
}
// 检查管理员权限
if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
......
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
Email: "admin@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
TokenVersion: 2,
Concurrency: 1,
}
userRepo := &stubUserRepo{
getByID: func(ctx context.Context, id int64) (*service.User, error) {
if id != admin.ID {
return nil, service.ErrUserNotFound
}
clone := *admin
return &clone, nil
},
}
userService := service.NewUserService(userRepo, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
}
type stubUserRepo struct {
getByID func(ctx context.Context, id int64) (*service.User, error)
}
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
panic("unexpected Create call")
}
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
if s.getByID == nil {
panic("GetByID not stubbed")
}
return s.getByID(ctx, id)
}
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
panic("unexpected GetByEmail call")
}
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
panic("unexpected Update call")
}
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
......@@ -3,7 +3,6 @@ package middleware
import (
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -36,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
apiKeyString = strings.TrimSpace(parts[1])
}
}
......@@ -97,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c)
clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
......@@ -126,6 +125,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
return
}
......@@ -134,7 +134,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
// 订阅模式:获取订阅(L1 缓存 + singleflight)
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
......@@ -145,30 +145,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制(使用0作为额外费用进行预检查)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
// 合并验证 + 限额检查(纯内存操作)
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
}
AbortWithError(c, status, code, err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
......@@ -185,6 +185,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
}
......
......@@ -64,6 +64,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
return
}
......@@ -104,6 +105,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment