Commit 3d79773b authored by kyx236's avatar kyx236
Browse files

Merge branch 'main' of https://github.com/james-6-23/sub2api

parents 6aa8cbbf 742e73c9
//go:build integration
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// hashedTestValue derives a SHA-256 hex digest (64 characters) from a
// per-test unique value, so the result is both unique per test and always
// fits VARCHAR(64) columns.
func hashedTestValue(t *testing.T, prefix string) string {
	t.Helper()
	unique := uniqueTestValue(t, prefix)
	digest := sha256.Sum256([]byte(unique))
	return hex.EncodeToString(digest[:])
}
// TestIdempotencyRepo_CreateProcessing_CompeteSameKey verifies that the first
// CreateProcessing call for a scope+key-hash pair wins ownership, and that a
// second call with the same pair (even with a different fingerprint) is
// de-duplicated instead of creating a second owner.
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-create"),
		IdempotencyKeyHash: hashedTestValue(t, "idem-hash"),
		RequestFingerprint: hashedTestValue(t, "idem-fp"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(30 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	// The repository is expected to backfill the new row's ID.
	require.NotZero(t, record.ID)
	// Same scope + key hash, different fingerprint: must not win ownership.
	duplicate := &service.IdempotencyRecord{
		Scope:              record.Scope,
		IdempotencyKeyHash: record.IdempotencyKeyHash,
		RequestFingerprint: hashedTestValue(t, "idem-fp-other"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(30 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err = repo.CreateProcessing(ctx, duplicate)
	require.NoError(t, err)
	require.False(t, owner, "same scope+key hash should be de-duplicated")
}
// TestIdempotencyRepo_TryReclaim_StatusAndLockWindow verifies both reclaim
// conditions: a failed_retryable record whose lock already expired can be
// reclaimed (flipping it back to processing with a fresh lock), while a
// record whose lock window is still open cannot.
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-reclaim"),
		IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"),
		RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(10 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	// Mark the record retryable with a lock that expired 2 seconds ago.
	require.NoError(t, repo.MarkFailedRetryable(
		ctx,
		record.ID,
		"RETRYABLE_FAILURE",
		now.Add(-2*time.Second),
		now.Add(24*time.Hour),
	))
	newLockedUntil := now.Add(20 * time.Second)
	reclaimed, err := repo.TryReclaim(
		ctx,
		record.ID,
		service.IdempotencyStatusFailedRetryable,
		now,
		newLockedUntil,
		now.Add(24*time.Hour),
	)
	require.NoError(t, err)
	require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
	// Reclaim should flip the record back to processing with a future lock.
	got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
	require.NotNil(t, got.LockedUntil)
	require.True(t, got.LockedUntil.After(now))
	// Second round: mark retryable again, but with a lock 20s in the future.
	require.NoError(t, repo.MarkFailedRetryable(
		ctx,
		record.ID,
		"RETRYABLE_FAILURE",
		now.Add(20*time.Second),
		now.Add(24*time.Hour),
	))
	reclaimed, err = repo.TryReclaim(
		ctx,
		record.ID,
		service.IdempotencyStatusFailedRetryable,
		now,
		now.Add(40*time.Second),
		now.Add(24*time.Hour),
	)
	require.NoError(t, err)
	require.False(t, reclaimed, "within lock window should not reclaim")
}
// TestIdempotencyRepo_StatusTransition_ToSucceeded verifies that MarkSucceeded
// stores the response status and body, flips the record to the succeeded
// status, and releases the processing lock.
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-success"),
		IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"),
		RequestFingerprint: hashedTestValue(t, "idem-fp-success"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(10 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
	got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
	require.NotNil(t, got.ResponseStatus)
	require.Equal(t, 200, *got.ResponseStatus)
	require.NotNil(t, got.ResponseBody)
	require.Equal(t, `{"ok":true}`, *got.ResponseBody)
	// Success releases the processing lock.
	require.Nil(t, got.LockedUntil)
}
......@@ -12,7 +12,7 @@ import (
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
fingerprintTTL = 7 * 24 * time.Hour // 7天,配合每24小时懒续期可保持活跃账号永不过期
maskedSessionKeyPrefix = "masked_session:"
maskedSessionTTL = 15 * time.Minute
)
......
......@@ -50,6 +50,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
// migrationsAdvisoryLockID serializes concurrent migration runs via a
// PostgreSQL advisory lock. Any stable int64 value works, as long as it does
// not collide with other advisory locks used in the same database.
const migrationsAdvisoryLockID int64 = 694208311321144027

// migrationsLockRetryInterval is the delay between advisory-lock acquisition attempts.
const migrationsLockRetryInterval = 500 * time.Millisecond

// nonTransactionalMigrationSuffix marks migration files that must run outside
// a transaction (e.g. CREATE/DROP INDEX CONCURRENTLY).
const nonTransactionalMigrationSuffix = "_notx.sql"
// migrationChecksumCompatibilityRule whitelists one historical checksum drift:
// fileChecksum is the current on-disk checksum of the migration file, and
// acceptedDBChecksum is the set of historical checksums recorded in databases
// that are still considered equivalent.
type migrationChecksumCompatibilityRule struct {
	fileChecksum       string
	acceptedDBChecksum map[string]struct{}
}

// migrationChecksumCompatibilityRules exists only to tolerate migration files
// whose checksums were mistakenly modified in the past. A rule is applied only
// when migration name + current file checksum + historical DB checksum all
// match, so the global immutability check is never loosened.
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
	"054_drop_legacy_cache_columns.sql": {
		fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
		acceptedDBChecksum: map[string]struct{}{
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
		},
	},
	"061_add_usage_log_request_type.sql": {
		fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
		acceptedDBChecksum: map[string]struct{}{
			"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
			"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
		},
	},
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
......@@ -147,6 +171,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
if isMigrationChecksumCompatible(name, existing, checksum) {
continue
}
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf(
......@@ -165,8 +193,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
nonTx, err := validateMigrationExecutionMode(name, content)
if err != nil {
return fmt.Errorf("validate migration %s: %w", name, err)
}
if nonTx {
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
for i, stmt := range statements {
trimmed := strings.TrimSpace(stmt)
if trimmed == "" {
continue
}
if stripSQLLineComment(trimmed) == "" {
continue
}
if _, err := db.ExecContext(ctx, trimmed); err != nil {
return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err)
}
}
if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
return fmt.Errorf("record migration %s (non-tx): %w", name, err)
}
continue
}
// 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
......@@ -268,6 +322,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
// isMigrationChecksumCompatible reports whether a checksum mismatch for the
// named migration is covered by an explicit compatibility rule: a rule must
// exist for the migration, the on-disk file checksum must match the rule, and
// the checksum recorded in the database must be in the rule's accepted set.
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
	if rule, found := migrationChecksumCompatibilityRules[name]; found && rule.fileChecksum == fileChecksum {
		_, accepted := rule.acceptedDBChecksum[dbChecksum]
		return accepted
	}
	return false
}
// validateMigrationExecutionMode decides whether a migration must run outside
// a transaction (*_notx.sql suffix) and enforces each mode's constraints:
//   - regular migrations must not contain CONCURRENTLY anywhere;
//   - *_notx.sql migrations must not contain transaction control keywords,
//     may only contain CREATE/DROP INDEX CONCURRENTLY statements, and those
//     statements must be idempotent (IF NOT EXISTS / IF EXISTS).
//
// It returns true when the migration should be executed non-transactionally.
// NOTE: the checks are substring-based on the uppercased content, so a word
// like "BEGIN" or "CONCURRENTLY" appearing anywhere (even inside identifiers)
// triggers them — deliberately strict.
func validateMigrationExecutionMode(name, content string) (bool, error) {
	lowerName := strings.ToLower(strings.TrimSpace(name))
	upper := strings.ToUpper(content)

	if !strings.HasSuffix(lowerName, nonTransactionalMigrationSuffix) {
		// Transactional mode: CONCURRENTLY statements cannot run inside a
		// transaction, so they must live in a *_notx.sql migration.
		if strings.Contains(upper, "CONCURRENTLY") {
			return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations")
		}
		return false, nil
	}

	// Non-transactional mode: transaction control keywords are forbidden.
	for _, keyword := range []string{"BEGIN", "COMMIT", "ROLLBACK"} {
		if strings.Contains(upper, keyword) {
			return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)")
		}
	}

	for _, raw := range splitSQLStatements(content) {
		stmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(raw)))
		if stmt == "" {
			continue // comment-only or blank fragment
		}
		if !strings.Contains(stmt, "CONCURRENTLY") {
			return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements")
		}
		isCreateIndex := strings.Contains(stmt, "CREATE") && strings.Contains(stmt, "INDEX")
		isDropIndex := strings.Contains(stmt, "DROP") && strings.Contains(stmt, "INDEX")
		switch {
		case !isCreateIndex && !isDropIndex:
			return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements")
		case isCreateIndex && !strings.Contains(stmt, "IF NOT EXISTS"):
			return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency")
		case isDropIndex && !strings.Contains(stmt, "IF EXISTS"):
			return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency")
		}
	}
	return true, nil
}
// splitSQLStatements splits raw SQL into individual statements on ";".
// Fragments are returned untrimmed; whitespace-only fragments are dropped.
// This is a naive split — it does not understand semicolons inside string
// literals or dollar-quoted bodies, which is sufficient for the simple index
// migrations it serves.
func splitSQLStatements(content string) []string {
	var statements []string
	for _, segment := range strings.Split(content, ";") {
		if strings.TrimSpace(segment) != "" {
			statements = append(statements, segment)
		}
	}
	return statements
}
// stripSQLLineComment removes "--" line comments from every line of s and
// trims surrounding whitespace from the result. It does not recognize "--"
// inside string literals, so it should only be relied on for simple DDL.
func stripSQLLineComment(s string) string {
	var b strings.Builder
	for i, line := range strings.Split(s, "\n") {
		if i > 0 {
			b.WriteByte('\n')
		}
		if idx := strings.Index(line, "--"); idx >= 0 {
			line = line[:idx]
		}
		b.WriteString(line)
	}
	return strings.TrimSpace(b.String())
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
......
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsMigrationChecksumCompatible(t *testing.T) {
t.Run("054历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"054_drop_legacy_cache_columns.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
)
require.True(t, ok)
})
t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"054_drop_legacy_cache_columns.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"0000000000000000000000000000000000000000000000000000000000000000",
)
require.False(t, ok)
})
t.Run("061历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"061_add_usage_log_request_type.sql",
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0",
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
)
require.True(t, ok)
})
t.Run("061第二个历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"061_add_usage_log_request_type.sql",
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3",
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
)
require.True(t, ok)
})
t.Run("非白名单迁移不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"001_init.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
)
require.False(t, ok)
})
}
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io/fs"
"strings"
"testing"
"testing/fstest"
"time"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
// TestApplyMigrations_NilDB ensures the entry point rejects a nil database
// handle before doing any work.
func TestApplyMigrations_NilDB(t *testing.T) {
	err := ApplyMigrations(context.Background(), nil)
	require.ErrorContains(t, err, "nil sql db")
}
// TestApplyMigrations_DelegatesToApplyMigrationsFS drives the public entry
// point far enough to prove it attempts the advisory lock before applying
// any migration.
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnError(errors.New("lock failed"))
	err = ApplyMigrations(context.Background(), db)
	require.Error(t, err)
	require.Contains(t, err.Error(), "acquire migrations lock")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestLatestMigrationBaseline covers baseline derivation from a migrations
// filesystem: the fallback when no .sql files exist, selection of the
// lexicographically last migration, and read-error propagation.
func TestLatestMigrationBaseline(t *testing.T) {
	t.Run("empty_fs_returns_baseline", func(t *testing.T) {
		version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
		require.NoError(t, err)
		require.Equal(t, "baseline", version)
		require.Equal(t, "baseline", description)
		require.Equal(t, "", hash)
	})
	t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
			"010_final.sql": &fstest.MapFile{
				Data: []byte("CREATE TABLE t2(id int);"),
			},
		}
		version, description, hash, err := latestMigrationBaseline(fsys)
		require.NoError(t, err)
		require.Equal(t, "010_final", version)
		require.Equal(t, "010_final", description)
		// SHA-256 hex digest length.
		require.Len(t, hash, 64)
	})
	t.Run("read_file_error", func(t *testing.T) {
		// A directory entry under a .sql name forces the file read to fail.
		fsys := fstest.MapFS{
			"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
		}
		_, _, _, err := latestMigrationBaseline(fsys)
		require.Error(t, err)
	})
}
// TestIsMigrationChecksumCompatible_AdditionalCases probes the negative and
// positive paths using whichever rule the whitelist map yields first.
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
	require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))

	// Pick an arbitrary rule from the whitelist (map order is irrelevant here).
	var (
		ruleName string
		rule     migrationChecksumCompatibilityRule
	)
	for n, r := range migrationChecksumCompatibilityRules {
		ruleName, rule = n, r
		break
	}
	require.NotEmpty(t, ruleName)

	require.False(t, isMigrationChecksumCompatible(ruleName, "db-not-accepted", "file-not-match"))
	require.False(t, isMigrationChecksumCompatible(ruleName, "db-not-accepted", rule.fileChecksum))

	// Any accepted DB checksum combined with the matching file checksum passes.
	var accepted string
	for checksum := range rule.acceptedDBChecksum {
		accepted = checksum
		break
	}
	require.NotEmpty(t, accepted)
	require.True(t, isMigrationChecksumCompatible(ruleName, accepted, rule.fileChecksum))
}
// TestEnsureAtlasBaselineAligned covers the legacy schema_migrations ->
// atlas_schema_revisions baseline alignment: the skip path, the
// create-and-insert path, and each error branch along the way. sqlmock
// expectations are strictly ordered, mirroring the production query sequence.
func TestEnsureAtlasBaselineAligned(t *testing.T) {
	t.Run("skip_when_no_legacy_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
			WillReturnResult(sqlmock.NewResult(0, 0))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
		// The baseline row must reference the latest migration (002_next).
		mock.ExpectExec("INSERT INTO atlas_schema_revisions").
			WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
			WillReturnResult(sqlmock.NewResult(1, 1))
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
			"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
		}
		err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_checking_legacy_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnError(errors.New("exists failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "check schema_migrations")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnError(errors.New("count failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "count atlas_schema_revisions")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_creating_atlas_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
			WillReturnError(errors.New("create failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "create atlas_schema_revisions")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_inserting_baseline", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
		mock.ExpectExec("INSERT INTO atlas_schema_revisions").
			WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
			WillReturnError(errors.New("insert failed"))
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
		}
		err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
		require.Error(t, err)
		require.Contains(t, err.Error(), "insert atlas baseline")
		require.NoError(t, mock.ExpectationsWereMet())
	})
}
// TestApplyMigrationsFS_ChecksumMismatchRejected asserts that an applied
// migration whose recorded checksum differs from the on-disk file aborts the
// run — a mutated migration file must never be silently re-accepted.
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_init.sql").
		WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
	// The advisory lock is released even on failure.
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "checksum mismatch")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_CheckMigrationQueryError verifies that a failure of
// the checksum lookup query is wrapped and propagated.
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_err.sql").
		WillReturnError(errors.New("query failed"))
	// The advisory lock is released even on failure.
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "check migration 001_err.sql")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied verifies that blank
// migration files are ignored entirely and files whose recorded checksum
// matches the on-disk content are not re-applied.
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	alreadySQL := "CREATE TABLE t(id int);"
	checksum := migrationChecksum(alreadySQL)
	// Only the non-empty file produces a checksum lookup; no exec follows
	// because the stored checksum matches.
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_already.sql").
		WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"000_empty.sql":   &fstest.MapFile{Data: []byte(" \n\t ")},
		"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_ReadMigrationError forces the migration file read to
// fail by registering a directory under a .sql name.
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	// The advisory lock is released even on failure.
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "read migration 001_bad.sql")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestPgAdvisoryLockAndUnlock_ErrorBranches exercises the advisory-lock
// helpers' failure and retry paths against sqlmock.
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
	t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
		// Timeout shorter than the retry interval, so the loop gives up
		// before a second attempt.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
		defer cancel()
		err = pgAdvisoryLock(ctx, db)
		require.Error(t, err)
		require.Contains(t, err.Error(), "acquire migrations lock")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("unlock_exec_error", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnError(errors.New("unlock failed"))
		err = pgAdvisoryUnlock(context.Background(), db)
		require.Error(t, err)
		require.Contains(t, err.Error(), "release migrations lock")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("acquire_lock_after_retry", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// First attempt fails to grab the lock, second succeeds.
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
		ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
		defer cancel()
		start := time.Now()
		err = pgAdvisoryLock(ctx, db)
		require.NoError(t, err)
		// At least one retry interval must have elapsed before success.
		require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
		require.NoError(t, mock.ExpectationsWereMet())
	})
}
// migrationChecksum mirrors the production checksum computation used by the
// tests: SHA-256 over the whitespace-trimmed migration content, hex-encoded.
func migrationChecksum(content string) string {
	digest := sha256.Sum256([]byte(strings.TrimSpace(content)))
	return hex.EncodeToString(digest[:])
}
package repository
import (
"context"
"database/sql"
"testing"
"testing/fstest"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
// TestValidateMigrationExecutionMode covers the *_notx.sql execution-mode
// rules: CONCURRENTLY placement, idempotency clauses (IF [NOT] EXISTS),
// the ban on transaction-control statements, and statement mixing.
func TestValidateMigrationExecutionMode(t *testing.T) {
	t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移禁止事务控制语句", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", `
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
DROP INDEX CONCURRENTLY IF EXISTS idx_b;
`)
		require.True(t, nonTx)
		require.NoError(t, err)
	})
}
// TestApplyMigrationsFS_NonTransactionalMigration verifies that a *_notx.sql
// migration executes its statement directly on the DB (no BEGIN/COMMIT) and
// is then recorded in schema_migrations.
func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_idx_notx.sql").
		WillReturnError(sql.ErrNoRows)
	// Statement runs outside any transaction (no ExpectBegin/ExpectCommit).
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_idx_notx.sql": &fstest.MapFile{
			Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements verifies
// that a *_notx.sql migration with several statements executes them one by
// one (comment lines are tolerated) before recording the migration.
func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_multi_idx_notx.sql").
		WillReturnError(sql.ErrNoRows)
	// Each index statement is executed separately, in file order.
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_multi_idx_notx.sql": &fstest.MapFile{
			Data: []byte(`
-- first
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
-- second
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
`),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_TransactionalMigration verifies that a regular
// migration runs inside BEGIN/COMMIT along with its bookkeeping insert.
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_col.sql").
		WillReturnError(sql.ErrNoRows)
	// DDL and the schema_migrations insert share one transaction.
	mock.ExpectBegin()
	mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_col.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_col.sql": &fstest.MapFile{
			Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// prepareMigrationsBootstrapExpectations queues the sqlmock expectations
// shared by every applyMigrationsFS run: advisory lock acquired on the first
// try, schema_migrations bootstrap DDL, and an already-aligned atlas baseline
// (both tables exist, revisions table non-empty).
func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
	mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
	mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectQuery("SELECT EXISTS \\(").
		WithArgs("schema_migrations").
		WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
	mock.ExpectQuery("SELECT EXISTS \\(").
		WithArgs("atlas_schema_revisions").
		WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
}
......@@ -42,12 +42,19 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// usage_logs: billing_type used by filters/stats
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
// security_secrets table should exist
var securitySecretsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass))
require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist")
// user_allowed_groups table should exist
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
......
......@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
......@@ -21,16 +22,23 @@ type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
client, err := createOpenAIReqClient(proxyURL)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
}
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
clientID = strings.TrimSpace(clientID)
if clientID == "" {
clientID = openai.ClientID
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", openai.ClientID)
formData.Set("client_id", clientID)
formData.Set("code", code)
formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier)
......@@ -56,12 +64,28 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
// 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID
clientID = strings.TrimSpace(clientID)
if clientID == "" {
clientID = openai.ClientID
}
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client, err := createOpenAIReqClient(proxyURL)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
}
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID)
formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
......@@ -84,7 +108,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
......
......@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
......@@ -136,13 +136,84 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
// TestRefreshToken_DefaultsToOpenAIClientID verifies that when no client_id is
// specified, the default OpenAI ClientID is used and exactly one request is
// sent (no blind guessing across multiple client_ids).
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
	var gotClientIDs []string
	handler := func(w http.ResponseWriter, r *http.Request) {
		if parseErr := r.ParseForm(); parseErr != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		gotClientIDs = append(gotClientIDs, r.PostForm.Get("client_id"))
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
	}
	s.setupServer(http.HandlerFunc(handler))

	resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
	require.NoError(s.T(), err, "RefreshToken")
	require.Equal(s.T(), "at", resp.AccessToken)
	// Exactly one request was sent, carrying the default OpenAI ClientID.
	require.Equal(s.T(), []string{openai.ClientID}, gotClientIDs)
}
// TestRefreshToken_UseSoraClientID verifies that an explicitly provided Sora
// ClientID is used as-is, with no fallback to any other client_id.
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
	var gotClientIDs []string
	handler := func(w http.ResponseWriter, r *http.Request) {
		if parseErr := r.ParseForm(); parseErr != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		id := r.PostForm.Get("client_id")
		gotClientIDs = append(gotClientIDs, id)
		if id != openai.SoraClientID {
			// Any other client_id is rejected so a fallback attempt would fail loudly.
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
	}
	s.setupServer(http.HandlerFunc(handler))

	resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
	require.NoError(s.T(), err, "RefreshTokenWithClientID")
	require.Equal(s.T(), "at-sora", resp.AccessToken)
	require.Equal(s.T(), []string{openai.SoraClientID}, gotClientIDs)
}
// TestRefreshToken_UseProvidedClientID verifies that a caller-supplied custom
// client_id is forwarded verbatim in the refresh request form.
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
	const customClientID = "custom-client-id"
	var gotClientIDs []string
	handler := func(w http.ResponseWriter, r *http.Request) {
		if parseErr := r.ParseForm(); parseErr != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		id := r.PostForm.Get("client_id")
		gotClientIDs = append(gotClientIDs, id)
		if id != customClientID {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
	}
	s.setupServer(http.HandlerFunc(handler))

	resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
	require.NoError(s.T(), err, "RefreshTokenWithClientID")
	require.Equal(s.T(), "at-custom", resp.AccessToken)
	require.Equal(s.T(), "rt-custom", resp.RefreshToken)
	require.Equal(s.T(), []string{customClientID}, gotClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
......@@ -152,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
......@@ -169,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
done <- err
}()
......@@ -195,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("client_id"); got != wantClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID)
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
......@@ -213,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
......@@ -229,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
......
......@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
......@@ -55,6 +56,10 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
......@@ -64,7 +69,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id`
var id int64
......@@ -97,6 +102,10 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON),
opsNullInt64(input.AuthLatencyMs),
opsNullInt64(input.RoutingLatencyMs),
opsNullInt64(input.UpstreamLatencyMs),
opsNullInt64(input.ResponseLatencyMs),
opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated,
......@@ -930,6 +939,243 @@ WHERE id = $1`
return err
}
// BatchInsertSystemLogs bulk-inserts system log rows into ops_system_logs
// using the PostgreSQL COPY protocol (pq.CopyIn) inside a single transaction,
// and returns the number of rows queued for insertion.
//
// Rows are normalized before queueing: level is lower-cased, level/component/
// message are whitespace-trimmed, an empty component defaults to "app", a zero
// CreatedAt defaults to now (UTC), and an empty extra payload defaults to "{}".
// Nil rows and rows with an empty level or message are silently skipped.
//
// NOTE(review): the returned count reflects rows queued into the COPY stream;
// on a flush/close/commit failure that count is returned alongside the error
// even though nothing was persisted — callers should treat it as best-effort.
func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) {
	if r == nil || r.db == nil {
		return 0, fmt.Errorf("nil ops repository")
	}
	if len(inputs) == 0 {
		return 0, nil
	}
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, err
	}
	// pq.CopyIn builds a COPY ... FROM STDIN statement; each ExecContext call
	// with arguments appends one row to the stream (order of steps matters).
	stmt, err := tx.PrepareContext(ctx, pq.CopyIn(
		"ops_system_logs",
		"created_at",
		"level",
		"component",
		"message",
		"request_id",
		"client_request_id",
		"user_id",
		"account_id",
		"platform",
		"model",
		"extra",
	))
	if err != nil {
		_ = tx.Rollback()
		return 0, err
	}
	var inserted int64
	for _, input := range inputs {
		if input == nil {
			continue
		}
		createdAt := input.CreatedAt
		if createdAt.IsZero() {
			createdAt = time.Now().UTC()
		}
		component := strings.TrimSpace(input.Component)
		level := strings.ToLower(strings.TrimSpace(input.Level))
		message := strings.TrimSpace(input.Message)
		if level == "" || message == "" {
			// Level and message are mandatory; drop malformed rows instead of
			// failing the whole batch.
			continue
		}
		if component == "" {
			component = "app"
		}
		extra := strings.TrimSpace(input.ExtraJSON)
		if extra == "" {
			extra = "{}"
		}
		if _, err := stmt.ExecContext(
			ctx,
			createdAt.UTC(),
			level,
			component,
			message,
			opsNullString(input.RequestID),
			opsNullString(input.ClientRequestID),
			opsNullInt64(input.UserID),
			opsNullInt64(input.AccountID),
			opsNullString(input.Platform),
			opsNullString(input.Model),
			extra,
		); err != nil {
			_ = stmt.Close()
			_ = tx.Rollback()
			return inserted, err
		}
		inserted++
	}
	// A final ExecContext with no arguments flushes the buffered COPY data to
	// the server; errors from the whole stream can surface here.
	if _, err := stmt.ExecContext(ctx); err != nil {
		_ = stmt.Close()
		_ = tx.Rollback()
		return inserted, err
	}
	if err := stmt.Close(); err != nil {
		_ = tx.Rollback()
		return inserted, err
	}
	if err := tx.Commit(); err != nil {
		return inserted, err
	}
	return inserted, nil
}
// ListSystemLogs returns one page of system logs matching the filter, newest
// first, together with the total match count for pagination. Page defaults to
// 1 and page size is clamped to the range [50 default, 200 max].
func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		filter = &service.OpsSystemLogFilter{}
	}

	page, pageSize := filter.Page, filter.PageSize
	if page < 1 {
		page = 1
	}
	switch {
	case pageSize < 1:
		pageSize = 50
	case pageSize > 200:
		pageSize = 200
	}

	where, args, _ := buildOpsSystemLogsWhere(filter)

	var total int
	if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_system_logs l "+where, args...).Scan(&total); err != nil {
		return nil, err
	}

	offset := (page - 1) * pageSize
	listSQL := `
SELECT
	l.id,
	l.created_at,
	l.level,
	COALESCE(l.component, ''),
	COALESCE(l.message, ''),
	COALESCE(l.request_id, ''),
	COALESCE(l.client_request_id, ''),
	l.user_id,
	l.account_id,
	COALESCE(l.platform, ''),
	COALESCE(l.model, ''),
	COALESCE(l.extra::text, '{}')
FROM ops_system_logs l
` + where + `
ORDER BY l.created_at DESC, l.id DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
	queryArgs := append(args, pageSize, offset)

	rows, err := r.db.QueryContext(ctx, listSQL, queryArgs...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	logs := make([]*service.OpsSystemLog, 0, pageSize)
	for rows.Next() {
		entry := &service.OpsSystemLog{}
		var userID, accountID sql.NullInt64
		var rawExtra string
		if err := rows.Scan(
			&entry.ID,
			&entry.CreatedAt,
			&entry.Level,
			&entry.Component,
			&entry.Message,
			&entry.RequestID,
			&entry.ClientRequestID,
			&userID,
			&accountID,
			&entry.Platform,
			&entry.Model,
			&rawExtra,
		); err != nil {
			return nil, err
		}
		if userID.Valid {
			id := userID.Int64
			entry.UserID = &id
		}
		if accountID.Valid {
			id := accountID.Int64
			entry.AccountID = &id
		}
		// Decode the extra payload lazily; malformed or empty JSON is ignored
		// rather than failing the listing.
		if trimmed := strings.TrimSpace(rawExtra); trimmed != "" && trimmed != "null" && trimmed != "{}" {
			parsed := make(map[string]any)
			if json.Unmarshal([]byte(trimmed), &parsed) == nil {
				entry.Extra = parsed
			}
		}
		logs = append(logs, entry)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return &service.OpsSystemLogList{
		Logs:     logs,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}
// DeleteSystemLogs removes system-log rows matching the cleanup filter and
// returns the number of deleted rows. It refuses to run without at least one
// real constraint, so an empty filter can never wipe the whole table.
func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) {
	if r == nil || r.db == nil {
		return 0, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		filter = &service.OpsSystemLogCleanupFilter{}
	}

	where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
	if !hasConstraint {
		return 0, fmt.Errorf("cleanup requires at least one filter condition")
	}

	result, err := r.db.ExecContext(ctx, "DELETE FROM ops_system_logs l "+where, args...)
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// InsertSystemLogCleanupAudit records a single cleanup operation (operator,
// conditions, and deleted row count) in the audit table. A zero CreatedAt
// defaults to the current time; timestamps are stored in UTC.
func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return fmt.Errorf("nil input")
	}

	when := input.CreatedAt
	if when.IsZero() {
		when = time.Now().UTC()
	}

	const auditSQL = `
	INSERT INTO ops_system_log_cleanup_audits (
		created_at,
		operator_id,
		conditions,
		deleted_rows
	) VALUES ($1,$2,$3,$4)
	`
	_, err := r.db.ExecContext(ctx, auditSQL, when.UTC(), input.OperatorID, input.Conditions, input.DeletedRows)
	return err
}
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses := make([]string, 0, 12)
args := make([]any, 0, 12)
......@@ -948,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
......@@ -962,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p)
clauses = append(clauses, "platform = $"+itoa(len(args)))
clauses = append(clauses, "e.platform = $"+itoa(len(args)))
}
if filter.GroupID != nil && *filter.GroupID > 0 {
args = append(args, *filter.GroupID)
clauses = append(clauses, "group_id = $"+itoa(len(args)))
clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
}
if filter.AccountID != nil && *filter.AccountID > 0 {
args = append(args, *filter.AccountID)
clauses = append(clauses, "account_id = $"+itoa(len(args)))
clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
}
if phase := phaseFilter; phase != "" {
args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
}
if filter != nil {
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
args = append(args, owner)
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
}
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
args = append(args, source)
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
}
}
if resolvedFilter != nil {
args = append(args, *resolvedFilter)
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
}
// View filter: errors vs excluded vs all.
......@@ -1000,51 +1246,140 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
case "excluded":
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
}
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
} else if filter.StatusCodesOther {
// "Other" means: status codes not in the common list.
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
args = append(args, pq.Array(known))
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
}
// Exact correlation keys (preferred for request↔upstream linkage).
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
args = append(args, rid)
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
}
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
args = append(args, crid)
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
}
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
like := "%" + userQuery + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "u.email ILIKE $"+n)
clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
// buildOpsSystemLogsWhere renders the WHERE clause for queries against
// ops_system_logs (table alias "l"). It returns the clause text, the
// positional arguments, and whether at least one real constraint was applied
// (the leading "1=1" placeholder does not count as a constraint).
func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) {
	clauses := []string{"1=1"}
	args := make([]any, 0, 10)
	constrained := false
	if filter == nil {
		return "WHERE " + strings.Join(clauses, " AND "), args, constrained
	}

	// addArg registers one positional argument and returns its placeholder index.
	addArg := func(v any) string {
		args = append(args, v)
		return itoa(len(args))
	}

	if filter.StartTime != nil && !filter.StartTime.IsZero() {
		clauses = append(clauses, "l.created_at >= $"+addArg(filter.StartTime.UTC()))
		constrained = true
	}
	if filter.EndTime != nil && !filter.EndTime.IsZero() {
		clauses = append(clauses, "l.created_at < $"+addArg(filter.EndTime.UTC()))
		constrained = true
	}
	if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" {
		clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+addArg(v))
		constrained = true
	}
	if v := strings.TrimSpace(filter.Component); v != "" {
		clauses = append(clauses, "COALESCE(l.component,'') = $"+addArg(v))
		constrained = true
	}
	if v := strings.TrimSpace(filter.RequestID); v != "" {
		clauses = append(clauses, "COALESCE(l.request_id,'') = $"+addArg(v))
		constrained = true
	}
	if v := strings.TrimSpace(filter.ClientRequestID); v != "" {
		clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+addArg(v))
		constrained = true
	}
	if filter.UserID != nil && *filter.UserID > 0 {
		clauses = append(clauses, "l.user_id = $"+addArg(*filter.UserID))
		constrained = true
	}
	if filter.AccountID != nil && *filter.AccountID > 0 {
		clauses = append(clauses, "l.account_id = $"+addArg(*filter.AccountID))
		constrained = true
	}
	if v := strings.TrimSpace(filter.Platform); v != "" {
		clauses = append(clauses, "COALESCE(l.platform,'') = $"+addArg(v))
		constrained = true
	}
	if v := strings.TrimSpace(filter.Model); v != "" {
		clauses = append(clauses, "COALESCE(l.model,'') = $"+addArg(v))
		constrained = true
	}
	if v := strings.TrimSpace(filter.Query); v != "" {
		// Free-text search spans message, correlation IDs, and the raw extra JSON.
		n := addArg("%" + v + "%")
		clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")")
		constrained = true
	}

	return "WHERE " + strings.Join(clauses, " AND "), args, constrained
}
// buildOpsSystemLogsCleanupWhere converts a cleanup filter into the same WHERE
// clause used for listing, so a delete targets exactly the rows a matching
// list query would return. The boolean result reports whether any real
// constraint is present.
func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) {
	src := filter
	if src == nil {
		src = &service.OpsSystemLogCleanupFilter{}
	}
	return buildOpsSystemLogsWhere(&service.OpsSystemLogFilter{
		StartTime:       src.StartTime,
		EndTime:         src.EndTime,
		Level:           src.Level,
		Component:       src.Component,
		RequestID:       src.RequestID,
		ClientRequestID: src.ClientRequestID,
		UserID:          src.UserID,
		AccountID:       src.AccountID,
		Platform:        src.Platform,
		Model:           src.Model,
		Query:           src.Query,
	})
}
// Helpers for nullable args
func opsNullString(v any) any {
switch s := v.(type) {
......
......@@ -12,6 +12,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
opsRawLatencyQueryTimeout = 2 * time.Second
opsRawPeakQueryTimeout = 1500 * time.Millisecond
)
func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
......@@ -45,16 +50,25 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic
func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
degraded := false
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
if err != nil {
......@@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
if err != nil {
return nil, err
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{
StartTime: start,
......@@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f
sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA))
errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA))
upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA))
degraded := false
// Keep "current" rates as raw, to preserve realtime semantics.
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
// NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont
// and keeps semantics consistent across modes.
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
if err != nil {
return nil, err
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{
StartTime: start,
......@@ -577,10 +630,17 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops
return nil, err
}
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
if isQueryTimeoutErr(err) {
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
if err != nil {
......@@ -735,69 +795,57 @@ FROM usage_logs ul
}
func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) {
{
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99,
AVG(duration_ms) AS avg_ms,
MAX(duration_ms) AS max_ms
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99,
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg,
MAX(duration_ms) AS duration_max,
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99,
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg,
MAX(first_token_ms) AS ttft_max
FROM usage_logs ul
` + join + `
` + where + `
AND duration_ms IS NOT NULL`
` + where
var p50, p90, p95, p99 sql.NullFloat64
var avg sql.NullFloat64
var max sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
var dP50, dP90, dP95, dP99 sql.NullFloat64
var dAvg sql.NullFloat64
var dMax sql.NullInt64
var tP50, tP90, tP95, tP99 sql.NullFloat64
var tAvg sql.NullFloat64
var tMax sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
&dP50, &dP90, &dP95, &dP99, &dAvg, &dMax,
&tP50, &tP90, &tP95, &tP99, &tAvg, &tMax,
); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
duration.P50 = floatToIntPtr(p50)
duration.P90 = floatToIntPtr(p90)
duration.P95 = floatToIntPtr(p95)
duration.P99 = floatToIntPtr(p99)
duration.Avg = floatToIntPtr(avg)
if max.Valid {
v := int(max.Int64)
duration.P50 = floatToIntPtr(dP50)
duration.P90 = floatToIntPtr(dP90)
duration.P95 = floatToIntPtr(dP95)
duration.P99 = floatToIntPtr(dP99)
duration.Avg = floatToIntPtr(dAvg)
if dMax.Valid {
v := int(dMax.Int64)
duration.Max = &v
}
}
{
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99,
AVG(first_token_ms) AS avg_ms,
MAX(first_token_ms) AS max_ms
FROM usage_logs ul
` + join + `
` + where + `
AND first_token_ms IS NOT NULL`
var p50, p90, p95, p99 sql.NullFloat64
var avg sql.NullFloat64
var max sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
ttft.P50 = floatToIntPtr(p50)
ttft.P90 = floatToIntPtr(p90)
ttft.P95 = floatToIntPtr(p95)
ttft.P99 = floatToIntPtr(p99)
ttft.Avg = floatToIntPtr(avg)
if max.Valid {
v := int(max.Int64)
ttft.P50 = floatToIntPtr(tP50)
ttft.P90 = floatToIntPtr(tP90)
ttft.P95 = floatToIntPtr(tP95)
ttft.P99 = floatToIntPtr(tP99)
ttft.Avg = floatToIntPtr(tAvg)
if tMax.Valid {
v := int(tMax.Int64)
ttft.Max = &v
}
}
return duration, ttft, nil
}
......@@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O
return qpsCurrent, tpsCurrent, nil
}
func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) {
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
q := `
WITH usage_buckets AS (
SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COUNT(*) AS req_cnt,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
......@@ -875,47 +926,33 @@ error_buckets AS (
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total
COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req,
COALESCE(u.token_cnt, 0) AS total_tokens
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT COALESCE(MAX(total), 0) FROM combined`
SELECT
COALESCE(MAX(total_req), 0) AS max_req_per_min,
COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min
FROM combined`
args := append(usageArgs, errorArgs...)
var maxPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
return 0, err
var maxReqPerMinute, maxTokensPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil {
return 0, 0, err
}
if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 {
qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0)
}
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
return 0, nil
if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 {
tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0)
}
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
return qpsPeak, tpsPeak, nil
}
func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT COALESCE(MAX(tokens_per_min), 0)
FROM (
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min
FROM usage_logs ul
` + join + `
` + where + `
GROUP BY 1
) t`
var maxPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
return 0, err
}
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
return 0, nil
}
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
// isQueryTimeoutErr reports whether err is, or wraps, context.DeadlineExceeded —
// i.e. whether a query failed because its context deadline elapsed. A plain
// context.Canceled is deliberately not treated as a timeout.
func isQueryTimeoutErr(err error) bool {
	return errors.Is(err, context.DeadlineExceeded)
}
func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) {
......
package repository
import (
"context"
"fmt"
"testing"
)
// TestIsQueryTimeoutErr verifies that isQueryTimeoutErr classifies only
// context.DeadlineExceeded — bare or wrapped via %w — as a query timeout,
// and rejects context.Canceled in both forms.
func TestIsQueryTimeoutErr(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want bool
	}{
		{"deadline exceeded", context.DeadlineExceeded, true},
		{"wrapped deadline exceeded", fmt.Errorf("wrapped: %w", context.DeadlineExceeded), true},
		{"canceled", context.Canceled, false},
		{"wrapped canceled", fmt.Errorf("wrapped: %w", context.Canceled), false},
	}
	for _, tc := range cases {
		if got := isQueryTimeoutErr(tc.err); got != tc.want {
			t.Fatalf("%s: isQueryTimeoutErr = %v, want %v", tc.name, got, tc.want)
		}
	}
}
package repository
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns checks that a free-text
// Query filter produces exactly one bind argument and matches against the
// table-qualified request_id, client_request_id, and error_message columns.
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
	filter := &service.OpsErrorLogFilter{
		Query: "ACCESS_DENIED",
	}

	where, args := buildOpsErrorLogsWhere(filter)
	if where == "" {
		t.Fatalf("where should not be empty")
	}
	if len(args) != 1 {
		t.Fatalf("args len = %d, want 1", len(args))
	}

	// All three searchable columns must be qualified with the e. alias so the
	// query stays unambiguous when joined with other tables.
	for _, cond := range []string{
		"e.request_id ILIKE $",
		"e.client_request_id ILIKE $",
		"e.error_message ILIKE $",
	} {
		if !strings.Contains(where, cond) {
			t.Fatalf("where should include qualified condition %q: %s", cond, where)
		}
	}
}
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
filter := &service.OpsErrorLogFilter{
UserQuery: "admin@",
}
where, args := buildOpsErrorLogsWhere(filter)
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 1 {
t.Fatalf("args len = %d, want 1", len(args))
}
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
t.Fatalf("where should include EXISTS user email condition: %s", where)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment