Commit bb664d9b authored by yangjianbo's avatar yangjianbo
Browse files

feat(sync): full code sync from release

parent bfc7b339
...@@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() { ...@@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() {
}) })
} }
func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() {
g1 := &service.Group{
Name: "sort-g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g2 := &service.Group{
Name: "sort-g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g3 := &service.Group{
Name: "sort-g3",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
s.Require().NoError(s.repo.Create(s.ctx, g3))
err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
{ID: g1.ID, SortOrder: 30},
{ID: g2.ID, SortOrder: 10},
{ID: g3.ID, SortOrder: 20},
{ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准
})
s.Require().NoError(err)
got1, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
got2, err := s.repo.GetByID(s.ctx, g2.ID)
s.Require().NoError(err)
got3, err := s.repo.GetByID(s.ctx, g3.ID)
s.Require().NoError(err)
s.Require().Equal(30, got1.SortOrder)
s.Require().Equal(15, got2.SortOrder)
s.Require().Equal(20, got3.SortOrder)
}
func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() {
g1 := &service.Group{
Name: "sort-no-partial",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
before, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
beforeSort := before.SortOrder
err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
{ID: g1.ID, SortOrder: 99},
{ID: 99999999, SortOrder: 1},
})
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
after, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
s.Require().Equal(beforeSort, after.SortOrder)
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{ g1 := &service.Group{
Name: "g1", Name: "g1",
......
...@@ -147,4 +147,3 @@ func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) { ...@@ -147,4 +147,3 @@ func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
require.Equal(t, `{"ok":true}`, *got.ResponseBody) require.Equal(t, `{"ok":true}`, *got.ResponseBody)
require.Nil(t, got.LockedUntil) require.Nil(t, got.LockedUntil)
} }
...@@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( ...@@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。 // 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql"
type migrationChecksumCompatibilityRule struct {
fileChecksum string
acceptedDBChecksum map[string]struct{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
"054_drop_legacy_cache_columns.sql": {
fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
acceptedDBChecksum: map[string]struct{}{
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
},
},
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
// //
...@@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { ...@@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if rowErr == nil { if rowErr == nil {
// 迁移已应用,验证校验和是否匹配 // 迁移已应用,验证校验和是否匹配
if existing != checksum { if existing != checksum {
// 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
if isMigrationChecksumCompatible(name, existing, checksum) {
continue
}
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。 // 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf( return fmt.Errorf(
...@@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { ...@@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return fmt.Errorf("check migration %s: %w", name, rowErr) return fmt.Errorf("check migration %s: %w", name, rowErr)
} }
// 迁移未应用,在事务中执行。 nonTx, err := validateMigrationExecutionMode(name, content)
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。 if err != nil {
return fmt.Errorf("validate migration %s: %w", name, err)
}
if nonTx {
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
for i, stmt := range statements {
trimmed := strings.TrimSpace(stmt)
if trimmed == "" {
continue
}
if stripSQLLineComment(trimmed) == "" {
continue
}
if _, err := db.ExecContext(ctx, trimmed); err != nil {
return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err)
}
}
if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
return fmt.Errorf("record migration %s (non-tx): %w", name, err)
}
continue
}
// 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err) return fmt.Errorf("begin migration %s: %w", name, err)
...@@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { ...@@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil return version, version, hash, nil
} }
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
rule, ok := migrationChecksumCompatibilityRules[name]
if !ok {
return false
}
if rule.fileChecksum != fileChecksum {
return false
}
_, ok = rule.acceptedDBChecksum[dbChecksum]
return ok
}
func validateMigrationExecutionMode(name, content string) (bool, error) {
normalizedName := strings.ToLower(strings.TrimSpace(name))
upperContent := strings.ToUpper(content)
nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix)
if !nonTx {
if strings.Contains(upperContent, "CONCURRENTLY") {
return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations")
}
return false, nil
}
if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") {
return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)")
}
statements := splitSQLStatements(content)
for _, stmt := range statements {
normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt)))
if normalizedStmt == "" {
continue
}
if strings.Contains(normalizedStmt, "CONCURRENTLY") {
isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX")
isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX")
if !isCreateIndex && !isDropIndex {
return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements")
}
if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") {
return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency")
}
if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") {
return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency")
}
continue
}
return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements")
}
return true, nil
}
func splitSQLStatements(content string) []string {
parts := strings.Split(content, ";")
out := make([]string, 0, len(parts))
for _, part := range parts {
if strings.TrimSpace(part) == "" {
continue
}
out = append(out, part)
}
return out
}
func stripSQLLineComment(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
if idx := strings.Index(line, "--"); idx >= 0 {
lines[i] = line[:idx]
}
}
return strings.TrimSpace(strings.Join(lines, "\n"))
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。 // pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。 // Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。 // 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
......
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsMigrationChecksumCompatible(t *testing.T) {
t.Run("054历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"054_drop_legacy_cache_columns.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
)
require.True(t, ok)
})
t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"054_drop_legacy_cache_columns.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"0000000000000000000000000000000000000000000000000000000000000000",
)
require.False(t, ok)
})
t.Run("非白名单迁移不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"001_init.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
)
require.False(t, ok)
})
}
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io/fs"
"strings"
"testing"
"testing/fstest"
"time"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestApplyMigrations_NilDB(t *testing.T) {
err := ApplyMigrations(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "nil sql db")
}
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnError(errors.New("lock failed"))
err = ApplyMigrations(context.Background(), db)
require.Error(t, err)
require.Contains(t, err.Error(), "acquire migrations lock")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestLatestMigrationBaseline(t *testing.T) {
t.Run("empty_fs_returns_baseline", func(t *testing.T) {
version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
require.NoError(t, err)
require.Equal(t, "baseline", version)
require.Equal(t, "baseline", description)
require.Equal(t, "", hash)
})
t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
"010_final.sql": &fstest.MapFile{
Data: []byte("CREATE TABLE t2(id int);"),
},
}
version, description, hash, err := latestMigrationBaseline(fsys)
require.NoError(t, err)
require.Equal(t, "010_final", version)
require.Equal(t, "010_final", description)
require.Len(t, hash, 64)
})
t.Run("read_file_error", func(t *testing.T) {
fsys := fstest.MapFS{
"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
}
_, _, _, err := latestMigrationBaseline(fsys)
require.Error(t, err)
})
}
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
var (
name string
rule migrationChecksumCompatibilityRule
)
for n, r := range migrationChecksumCompatibilityRules {
name = n
rule = r
break
}
require.NotEmpty(t, name)
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
var accepted string
for checksum := range rule.acceptedDBChecksum {
accepted = checksum
break
}
require.NotEmpty(t, accepted)
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
func TestEnsureAtlasBaselineAligned(t *testing.T) {
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
}
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_checking_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnError(errors.New("exists failed"))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.Error(t, err)
require.Contains(t, err.Error(), "check schema_migrations")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnError(errors.New("count failed"))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.Error(t, err)
require.Contains(t, err.Error(), "count atlas_schema_revisions")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_creating_atlas_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
WillReturnError(errors.New("create failed"))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.Error(t, err)
require.Contains(t, err.Error(), "create atlas_schema_revisions")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_inserting_baseline", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
WillReturnError(errors.New("insert failed"))
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
}
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "insert atlas baseline")
require.NoError(t, mock.ExpectationsWereMet())
})
}
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_init.sql").
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "checksum mismatch")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_err.sql").
WillReturnError(errors.New("query failed"))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "check migration 001_err.sql")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
alreadySQL := "CREATE TABLE t(id int);"
checksum := migrationChecksum(alreadySQL)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_already.sql").
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")},
"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "read migration 001_bad.sql")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer cancel()
err = pgAdvisoryLock(ctx, db)
require.Error(t, err)
require.Contains(t, err.Error(), "acquire migrations lock")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("unlock_exec_error", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnError(errors.New("unlock failed"))
err = pgAdvisoryUnlock(context.Background(), db)
require.Error(t, err)
require.Contains(t, err.Error(), "release migrations lock")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("acquire_lock_after_retry", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
defer cancel()
start := time.Now()
err = pgAdvisoryLock(ctx, db)
require.NoError(t, err)
require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
require.NoError(t, mock.ExpectationsWereMet())
})
}
func migrationChecksum(content string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(content)))
return hex.EncodeToString(sum[:])
}
package repository
import (
"context"
"database/sql"
"testing"
"testing/fstest"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestValidateMigrationExecutionMode(t *testing.T) {
t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移禁止事务控制语句", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", `
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
DROP INDEX CONCURRENTLY IF EXISTS idx_b;
`)
require.True(t, nonTx)
require.NoError(t, err)
})
}
func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_add_idx_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_add_idx_notx.sql": &fstest.MapFile{
Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_add_multi_idx_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_add_multi_idx_notx.sql": &fstest.MapFile{
Data: []byte(`
-- first
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
-- second
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_add_col.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectBegin()
mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("001_add_col.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_add_col.sql": &fstest.MapFile{
Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
}
...@@ -42,6 +42,8 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { ...@@ -42,6 +42,8 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// usage_logs: billing_type used by filters/stats // usage_logs: billing_type used by filters/stats
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false) requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// settings table should exist // settings table should exist
var settingsRegclass sql.NullString var settingsRegclass sql.NullString
......
...@@ -22,16 +22,20 @@ type openaiOAuthService struct { ...@@ -22,16 +22,20 @@ type openaiOAuthService struct {
tokenURL string tokenURL string
} }
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL) client := createOpenAIReqClient(proxyURL)
if redirectURI == "" { if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI redirectURI = openai.DefaultRedirectURI
} }
clientID = strings.TrimSpace(clientID)
if clientID == "" {
clientID = openai.ClientID
}
formData := url.Values{} formData := url.Values{}
formData.Set("grant_type", "authorization_code") formData.Set("grant_type", "authorization_code")
formData.Set("client_id", openai.ClientID) formData.Set("client_id", clientID)
formData.Set("code", code) formData.Set("code", code)
formData.Set("redirect_uri", redirectURI) formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier) formData.Set("code_verifier", codeVerifier)
...@@ -61,36 +65,12 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro ...@@ -61,36 +65,12 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
} }
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
if strings.TrimSpace(clientID) != "" { // 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID)) clientID = strings.TrimSpace(clientID)
} if clientID == "" {
clientID = openai.ClientID
clientIDs := []string{
openai.ClientID,
openai.SoraClientID,
}
seen := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
clientID = strings.TrimSpace(clientID)
if clientID == "" {
continue
}
if _, ok := seen[clientID]; ok {
continue
}
seen[clientID] = struct{}{}
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err == nil {
return tokenResp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
} }
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed") return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
} }
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
......
...@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() { ...@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
})) }))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "") resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "")
require.NoError(s.T(), err, "ExchangeCode") require.NoError(s.T(), err, "ExchangeCode")
select { select {
case msg := <-errCh: case msg := <-errCh:
...@@ -136,7 +136,9 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { ...@@ -136,7 +136,9 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken) require.Equal(s.T(), "rt2", resp.RefreshToken)
} }
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { // TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID,
// 且只发送一次请求(不再盲猜多个 client_id)。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
var seenClientIDs []string var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
...@@ -145,11 +147,27 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { ...@@ -145,11 +147,27 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
} }
clientID := r.PostForm.Get("client_id") clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID) seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.ClientID { w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
require.Equal(s.T(), "at", resp.AccessToken)
// 只发送了一次请求,使用默认的 OpenAI ClientID
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
}
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "invalid_grant")
return return
} }
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID { if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
...@@ -158,11 +176,10 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { ...@@ -158,11 +176,10 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
})) }))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "") resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshToken") require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken) require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), "rt-sora", resp.RefreshToken) require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
} }
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
...@@ -196,7 +213,7 @@ func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { ...@@ -196,7 +213,7 @@ func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
_, _ = io.WriteString(w, "bad") _, _ = io.WriteString(w, "bad")
})) }))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err) require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400") require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad") require.ErrorContains(s.T(), err, "bad")
...@@ -206,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { ...@@ -206,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close() s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err) require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed") require.ErrorContains(s.T(), err, "request failed")
} }
...@@ -223,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() { ...@@ -223,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "") _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
done <- err done <- err
}() }()
...@@ -249,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { ...@@ -249,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
})) }))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "") _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("client_id"); got != wantClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID)
require.NoError(s.T(), err, "ExchangeCode") require.NoError(s.T(), err, "ExchangeCode")
select { select {
case msg := <-errCh: case msg := <-errCh:
...@@ -267,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() { ...@@ -267,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
})) }))
s.svc.tokenURL = s.srv.URL + "?x=1" s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.NoError(s.T(), err, "ExchangeCode") require.NoError(s.T(), err, "ExchangeCode")
select { select {
case <-s.received: case <-s.received:
...@@ -283,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() { ...@@ -283,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
_, _ = io.WriteString(w, "not-valid-json") _, _ = io.WriteString(w, "not-valid-json")
})) }))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err, "expected error for invalid JSON response") require.Error(s.T(), err, "expected error for invalid JSON response")
} }
......
...@@ -12,6 +12,11 @@ import ( ...@@ -12,6 +12,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
const (
opsRawLatencyQueryTimeout = 2 * time.Second
opsRawPeakQueryTimeout = 1500 * time.Millisecond
)
func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
if r == nil || r.db == nil { if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository") return nil, fmt.Errorf("nil ops repository")
...@@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic ...@@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic
func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
start := filter.StartTime.UTC() start := filter.StartTime.UTC()
end := filter.EndTime.UTC() end := filter.EndTime.UTC()
degraded := false
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end) successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
if err != nil { if err != nil {
return nil, err return nil, err
} }
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil { if err != nil {
return nil, err if isQueryTimeoutErr(err) {
degraded = true
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
} }
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
...@@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser ...@@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil { if err != nil {
return nil, err if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
} }
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil { if err != nil {
return nil, err if isQueryTimeoutErr(err) {
} degraded = true
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) } else {
if err != nil { return nil, err
return nil, err }
} }
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{ return &service.OpsDashboardOverview{
StartTime: start, StartTime: start,
...@@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f ...@@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f
sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA)) sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA))
errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA)) errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA))
upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA)) upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA))
degraded := false
// Keep "current" rates as raw, to preserve realtime semantics. // Keep "current" rates as raw, to preserve realtime semantics.
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil { if err != nil {
return nil, err if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
} }
// NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
// and keeps semantics consistent across modes. qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) cancelPeak()
if err != nil { if err != nil {
return nil, err if isQueryTimeoutErr(err) {
} degraded = true
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) } else {
if err != nil { return nil, err
return nil, err }
} }
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{ return &service.OpsDashboardOverview{
StartTime: start, StartTime: start,
...@@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops ...@@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops
return nil, err return nil, err
} }
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil { if err != nil {
return nil, err if isQueryTimeoutErr(err) {
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
} }
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
...@@ -735,68 +795,56 @@ FROM usage_logs ul ...@@ -735,68 +795,56 @@ FROM usage_logs ul
} }
func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) { func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) {
{ join, where, args, _ := buildUsageWhere(filter, start, end, 1)
join, where, args, _ := buildUsageWhere(filter, start, end, 1) q := `
q := `
SELECT SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50, percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90, percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95, percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99, percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99,
AVG(duration_ms) AS avg_ms, AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg,
MAX(duration_ms) AS max_ms MAX(duration_ms) AS duration_max,
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99,
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg,
MAX(first_token_ms) AS ttft_max
FROM usage_logs ul FROM usage_logs ul
` + join + ` ` + join + `
` + where + ` ` + where
AND duration_ms IS NOT NULL`
var p50, p90, p95, p99 sql.NullFloat64 var dP50, dP90, dP95, dP99 sql.NullFloat64
var avg sql.NullFloat64 var dAvg sql.NullFloat64
var max sql.NullInt64 var dMax sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { var tP50, tP90, tP95, tP99 sql.NullFloat64
return service.OpsPercentiles{}, service.OpsPercentiles{}, err var tAvg sql.NullFloat64
} var tMax sql.NullInt64
duration.P50 = floatToIntPtr(p50) if err := r.db.QueryRowContext(ctx, q, args...).Scan(
duration.P90 = floatToIntPtr(p90) &dP50, &dP90, &dP95, &dP99, &dAvg, &dMax,
duration.P95 = floatToIntPtr(p95) &tP50, &tP90, &tP95, &tP99, &tAvg, &tMax,
duration.P99 = floatToIntPtr(p99) ); err != nil {
duration.Avg = floatToIntPtr(avg) return service.OpsPercentiles{}, service.OpsPercentiles{}, err
if max.Valid {
v := int(max.Int64)
duration.Max = &v
}
} }
{ duration.P50 = floatToIntPtr(dP50)
join, where, args, _ := buildUsageWhere(filter, start, end, 1) duration.P90 = floatToIntPtr(dP90)
q := ` duration.P95 = floatToIntPtr(dP95)
SELECT duration.P99 = floatToIntPtr(dP99)
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50, duration.Avg = floatToIntPtr(dAvg)
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90, if dMax.Valid {
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95, v := int(dMax.Int64)
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99, duration.Max = &v
AVG(first_token_ms) AS avg_ms, }
MAX(first_token_ms) AS max_ms
FROM usage_logs ul
` + join + `
` + where + `
AND first_token_ms IS NOT NULL`
var p50, p90, p95, p99 sql.NullFloat64 ttft.P50 = floatToIntPtr(tP50)
var avg sql.NullFloat64 ttft.P90 = floatToIntPtr(tP90)
var max sql.NullInt64 ttft.P95 = floatToIntPtr(tP95)
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { ttft.P99 = floatToIntPtr(tP99)
return service.OpsPercentiles{}, service.OpsPercentiles{}, err ttft.Avg = floatToIntPtr(tAvg)
} if tMax.Valid {
ttft.P50 = floatToIntPtr(p50) v := int(tMax.Int64)
ttft.P90 = floatToIntPtr(p90) ttft.Max = &v
ttft.P95 = floatToIntPtr(p95)
ttft.P99 = floatToIntPtr(p99)
ttft.Avg = floatToIntPtr(avg)
if max.Valid {
v := int(max.Int64)
ttft.Max = &v
}
} }
return duration, ttft, nil return duration, ttft, nil
...@@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O ...@@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O
return qpsCurrent, tpsCurrent, nil return qpsCurrent, tpsCurrent, nil
} }
func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) {
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1) usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next) errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
q := ` q := `
WITH usage_buckets AS ( WITH usage_buckets AS (
SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt SELECT
date_trunc('minute', ul.created_at) AS bucket,
COUNT(*) AS req_cnt,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt
FROM usage_logs ul FROM usage_logs ul
` + usageJoin + ` ` + usageJoin + `
` + usageWhere + ` ` + usageWhere + `
GROUP BY 1 GROUP BY 1
), ),
error_buckets AS ( error_buckets AS (
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt
FROM ops_error_logs FROM ops_error_logs
` + errorWhere + ` ` + errorWhere + `
AND COALESCE(status_code, 0) >= 400 AND COALESCE(status_code, 0) >= 400
...@@ -875,47 +926,33 @@ error_buckets AS ( ...@@ -875,47 +926,33 @@ error_buckets AS (
), ),
combined AS ( combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket, SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req,
COALESCE(u.token_cnt, 0) AS total_tokens
FROM usage_buckets u FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
) )
SELECT COALESCE(MAX(total), 0) FROM combined` SELECT
COALESCE(MAX(total_req), 0) AS max_req_per_min,
COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min
FROM combined`
args := append(usageArgs, errorArgs...) args := append(usageArgs, errorArgs...)
var maxPerMinute sql.NullInt64 var maxReqPerMinute, maxTokensPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil {
return 0, err return 0, 0, err
}
if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 {
qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0)
} }
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 {
return 0, nil tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0)
} }
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil return qpsPeak, tpsPeak, nil
} }
func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { func isQueryTimeoutErr(err error) bool {
join, where, args, _ := buildUsageWhere(filter, start, end, 1) return errors.Is(err, context.DeadlineExceeded)
q := `
SELECT COALESCE(MAX(tokens_per_min), 0)
FROM (
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min
FROM usage_logs ul
` + join + `
` + where + `
GROUP BY 1
) t`
var maxPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
return 0, err
}
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
return 0, nil
}
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
} }
func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) { func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) {
......
package repository
import (
"context"
"fmt"
"testing"
)
func TestIsQueryTimeoutErr(t *testing.T) {
if !isQueryTimeoutErr(context.DeadlineExceeded) {
t.Fatalf("context.DeadlineExceeded should be treated as query timeout")
}
if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) {
t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout")
}
if isQueryTimeoutErr(context.Canceled) {
t.Fatalf("context.Canceled should not be treated as query timeout")
}
if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) {
t.Fatalf("wrapped context.Canceled should not be treated as query timeout")
}
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
// 使用原生 SQL 操作 sora_generations 表。
type soraGenerationRepository struct {
sql *sql.DB
}
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
return &soraGenerationRepository{sql: sqlDB}
}
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
err := r.sql.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt)
return err
}
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
func (r *soraGenerationRepository) CreatePendingWithLimit(
ctx context.Context,
gen *service.SoraGeneration,
activeStatuses []string,
maxActive int64,
) error {
if gen == nil {
return fmt.Errorf("generation is nil")
}
if maxActive <= 0 {
return r.Create(ctx, gen)
}
if len(activeStatuses) == 0 {
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
}
tx, err := r.sql.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
return err
}
placeholders := make([]string, len(activeStatuses))
args := make([]any, 0, 1+len(activeStatuses))
args = append(args, gen.UserID)
for i, s := range activeStatuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
countQuery := fmt.Sprintf(
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
strings.Join(placeholders, ","),
)
var activeCount int64
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
return err
}
if activeCount >= maxActive {
return service.ErrSoraGenerationConcurrencyLimit
}
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
if err := tx.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
return err
}
return tx.Commit()
}
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
err := r.sql.QueryRowContext(ctx, `
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations WHERE id = $1
`, id).Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("生成记录不存在")
}
return nil, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
return gen, nil
}
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
var completedAt *time.Time
if gen.CompletedAt != nil {
completedAt = gen.CompletedAt
}
_, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations SET
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
error_message = $9, completed_at = $10
WHERE id = $1
`,
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
gen.ErrorMessage, completedAt,
)
return err
}
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, upstream_task_id = $3
WHERE id = $1 AND status = $4
`,
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
func (r *soraGenerationRepository) UpdateCompletedIfActive(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
completedAt time.Time,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
media_url = $3,
media_urls = $4,
file_size_bytes = $5,
storage_type = $6,
s3_object_keys = $7,
error_message = '',
completed_at = $8
WHERE id = $1 AND status IN ($9, $10)
`,
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
func (r *soraGenerationRepository) UpdateFailedIfActive(
ctx context.Context,
id int64,
errMsg string,
completedAt time.Time,
) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
error_message = $3,
completed_at = $4
WHERE id = $1 AND status IN ($5, $6)
`,
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, completed_at = $3
WHERE id = $1 AND status IN ($4, $5)
`,
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET media_url = $2,
media_urls = $3,
file_size_bytes = $4,
storage_type = $5,
s3_object_keys = $6
WHERE id = $1 AND status = $7
`,
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
return err
}
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
// 构建 WHERE 条件
conditions := []string{"user_id = $1"}
args := []any{params.UserID}
argIdx := 2
if params.Status != "" {
// 支持逗号分隔的多状态
statuses := strings.Split(params.Status, ",")
placeholders := make([]string, len(statuses))
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
}
if params.StorageType != "" {
storageTypes := strings.Split(params.StorageType, ",")
placeholders := make([]string, len(storageTypes))
for i, s := range storageTypes {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
}
if params.MediaType != "" {
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
args = append(args, params.MediaType)
argIdx++
}
whereClause := "WHERE " + strings.Join(conditions, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, 0, err
}
// 分页查询
offset := (params.Page - 1) * params.PageSize
listQuery := fmt.Sprintf(`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIdx, argIdx+1)
args = append(args, params.PageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
if err != nil {
return nil, 0, err
}
defer func() {
_ = rows.Close()
}()
var results []*service.SoraGeneration
for rows.Next() {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
if err := rows.Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
); err != nil {
return nil, 0, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
results = append(results, gen)
}
return results, total, rows.Err()
}
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
if len(statuses) == 0 {
return 0, nil
}
placeholders := make([]string, len(statuses))
args := []any{userID}
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err
}
...@@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) ...@@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
idx++ idx++
} }
} }
if filters.Stream != nil { if filters.RequestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType)
conditions = append(conditions, condition)
args = append(args, conditionArgs...)
idx += len(conditionArgs)
} else if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx)) conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream) args = append(args, *filters.Stream)
idx++ idx++
......
...@@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) { ...@@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) {
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args) require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
} }
func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeWSV2)
stream := false
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
RequestType: &requestType,
Stream: &stream,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where)
require.Equal(t, []any{start, end, requestType}, args)
}
func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeStream)
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
RequestType: &requestType,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where)
require.Equal(t, []any{start, end, requestType}, args)
}
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) { func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour) end := start.Add(24 * time.Hour)
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{ var dateFormatWhitelist = map[string]string{
...@@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.RequestID = requestID log.RequestID = requestID
rateMultiplier := log.RateMultiplier rateMultiplier := log.RateMultiplier
log.SyncRequestTypeAndLegacyFields()
requestType := int16(log.RequestType)
query := ` query := `
INSERT INTO usage_logs ( INSERT INTO usage_logs (
...@@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rate_multiplier, rate_multiplier,
account_rate_multiplier, account_rate_multiplier,
billing_type, billing_type,
request_type,
stream, stream,
openai_ws_mode,
duration_ms, duration_ms,
first_token_ms, first_token_ms,
user_agent, user_agent,
...@@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rateMultiplier, rateMultiplier,
log.AccountRateMultiplier, log.AccountRateMultiplier,
log.BillingType, log.BillingType,
requestType,
log.Stream, log.Stream,
log.OpenAIWSMode,
duration, duration,
firstToken, firstToken,
userAgent, userAgent,
...@@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte ...@@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
} }
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error { func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
totalStatsQuery := ` todayEnd := todayUTC.Add(24 * time.Hour)
combinedStatsQuery := `
WITH scoped AS (
SELECT
created_at,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
COALESCE(duration_ms, 0) AS duration_ms
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
)
SELECT SELECT
COUNT(*) as total_requests, COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens, COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
FROM usage_logs COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
WHERE created_at >= $1 AND created_at < $2 COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens,
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
FROM scoped
` `
var totalDurationMs int64 var totalDurationMs int64
if err := scanSingleRow( if err := scanSingleRow(
ctx, ctx,
r.sql, r.sql,
totalStatsQuery, combinedStatsQuery,
[]any{startUTC, endUTC}, []any{startUTC, endUTC, todayUTC, todayEnd},
&stats.TotalRequests, &stats.TotalRequests,
&stats.TotalInputTokens, &stats.TotalInputTokens,
&stats.TotalOutputTokens, &stats.TotalOutputTokens,
...@@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co ...@@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TotalCost, &stats.TotalCost,
&stats.TotalActualCost, &stats.TotalActualCost,
&totalDurationMs, &totalDurationMs,
); err != nil {
return err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
todayEnd := todayUTC.Add(24 * time.Hour)
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{todayUTC, todayEnd},
&stats.TodayRequests, &stats.TodayRequests,
&stats.TodayInputTokens, &stats.TodayInputTokens,
&stats.TodayOutputTokens, &stats.TodayOutputTokens,
...@@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co ...@@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
); err != nil { ); err != nil {
return err return err
} }
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
activeUsersQuery := ` stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
return err
} }
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
hourStart := now.UTC().Truncate(time.Hour) hourStart := now.UTC().Truncate(time.Hour)
hourEnd := hourStart.Add(time.Hour) hourEnd := hourStart.Add(time.Hour)
hourlyActiveQuery := ` activeUsersQuery := `
SELECT COUNT(DISTINCT user_id) as active_users WITH scoped AS (
FROM usage_logs SELECT user_id, created_at
WHERE created_at >= $1 AND created_at < $2 FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
)
SELECT
COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users,
COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users
FROM scoped
` `
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil { if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil {
return err return err
} }
...@@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc ...@@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
return result, nil return result, nil
} }
// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。
// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。
func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) {
result := make(map[int64]service.GeminiUsageTotals, len(accountIDs))
if len(accountIDs) == 0 {
return result, nil
}
query := `
SELECT
account_id,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost
FROM usage_logs
WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY account_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var accountID int64
var totals service.GeminiUsageTotals
if err := rows.Scan(
&accountID,
&totals.FlashRequests,
&totals.ProRequests,
&totals.FlashTokens,
&totals.ProTokens,
&totals.FlashCost,
&totals.ProCost,
); err != nil {
return nil, err
}
result[accountID] = totals
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, accountID := range accountIDs {
if _, ok := result[accountID]; !ok {
result[accountID] = service.GeminiUsageTotals{}
}
}
return result, nil
}
// TrendDataPoint represents a single point in trend data // TrendDataPoint represents a single point in trend data
type TrendDataPoint = usagestats.TrendDataPoint type TrendDataPoint = usagestats.TrendDataPoint
...@@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat ...@@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model) args = append(args, filters.Model)
} }
if filters.Stream != nil { conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil { if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType)) args = append(args, int16(*filters.BillingType))
...@@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe ...@@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
} }
// GetUsageTrendWithFilters returns usage trend data with optional filters // GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := safeDateFormat(granularity) dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(` query := fmt.Sprintf(`
...@@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND model = $%d", len(args)+1) query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model) args = append(args, model)
} }
if stream != nil { query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType)) args = append(args, int16(*billingType))
...@@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
} }
// GetModelStatsWithFilters returns model statistics with optional filters // GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 { if accountID > 0 && userID == 0 && apiKeyID == 0 {
...@@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start ...@@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID) args = append(args, groupID)
} }
if stream != nil { query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType)) args = append(args, int16(*billingType))
...@@ -1794,10 +1844,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -1794,10 +1844,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model) args = append(args, filters.Model)
} }
if filters.Stream != nil { conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil { if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType)) args = append(args, int16(*filters.BillingType))
...@@ -2017,7 +2064,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -2017,7 +2064,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
} }
} }
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil) models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil)
if err != nil { if err != nil {
models = []ModelStat{} models = []ModelStat{}
} }
...@@ -2267,7 +2314,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2267,7 +2314,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
rateMultiplier float64 rateMultiplier float64
accountRateMultiplier sql.NullFloat64 accountRateMultiplier sql.NullFloat64
billingType int16 billingType int16
requestTypeRaw int16
stream bool stream bool
openaiWSMode bool
durationMs sql.NullInt64 durationMs sql.NullInt64
firstTokenMs sql.NullInt64 firstTokenMs sql.NullInt64
userAgent sql.NullString userAgent sql.NullString
...@@ -2304,7 +2353,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2304,7 +2353,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&rateMultiplier, &rateMultiplier,
&accountRateMultiplier, &accountRateMultiplier,
&billingType, &billingType,
&requestTypeRaw,
&stream, &stream,
&openaiWSMode,
&durationMs, &durationMs,
&firstTokenMs, &firstTokenMs,
&userAgent, &userAgent,
...@@ -2340,11 +2391,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2340,11 +2391,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
RateMultiplier: rateMultiplier, RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier), AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType), BillingType: int8(billingType),
Stream: stream, RequestType: service.RequestTypeFromInt16(requestTypeRaw),
ImageCount: imageCount, ImageCount: imageCount,
CacheTTLOverridden: cacheTTLOverridden, CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: createdAt, CreatedAt: createdAt,
} }
// 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。
log.Stream = stream
log.OpenAIWSMode = openaiWSMode
log.RequestType = log.EffectiveRequestType()
log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode)
if requestID.Valid { if requestID.Valid {
log.RequestID = requestID.String log.RequestID = requestID.String
...@@ -2438,6 +2494,50 @@ func buildWhere(conditions []string) string { ...@@ -2438,6 +2494,50 @@ func buildWhere(conditions []string) string {
return "WHERE " + strings.Join(conditions, " AND ") return "WHERE " + strings.Join(conditions, " AND ")
} }
func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) {
if requestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
conditions = append(conditions, condition)
args = append(args, conditionArgs...)
return conditions, args
}
if stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *stream)
}
return conditions, args
}
func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) {
if requestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
query += " AND " + condition
args = append(args, conditionArgs...)
return query, args
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
return query, args
}
// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。
func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) {
normalized := service.RequestTypeFromInt16(requestType)
requestTypeArg := int16(normalized)
switch normalized {
case service.RequestTypeSync:
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
case service.RequestTypeStream:
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
case service.RequestTypeWSV2:
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
default:
return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg}
}
}
func nullInt64(v *int64) sql.NullInt64 { func nullInt64(v *int64) sql.NullInt64 {
if v == nil { if v == nil {
return sql.NullInt64{} return sql.NullInt64{}
......
...@@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() { ...@@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001) s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
} }
func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "gpt-5.3-codex",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 1.0,
OpenAIWSMode: true,
CreatedAt: timezone.Today().Add(3 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().True(got.OpenAIWSMode)
}
func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "gpt-5.3-codex",
RequestType: service.RequestTypeWSV2,
Stream: true,
OpenAIWSMode: false,
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 1.0,
CreatedAt: timezone.Today().Add(4 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().Equal(service.RequestTypeWSV2, got.RequestType)
s.Require().True(got.Stream)
s.Require().True(got.OpenAIWSMode)
}
// --- Delete --- // --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
...@@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { ...@@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour) endTime := base.Add(48 * time.Hour)
// Test with user filter // Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with apiKey filter // Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with both filters // Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
...@@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { ...@@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour) endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
...@@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
// Test with user filter // Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil) stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with apiKey filter // Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with account filter // Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
} }
......
package repository
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-1",
Model: "gpt-5",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1,
ActualCost: 1,
BillingType: service.BillingTypeBalance,
RequestType: service.RequestTypeWSV2,
Stream: false,
OpenAIWSMode: false,
CreatedAt: createdAt,
}
mock.ExpectQuery("INSERT INTO usage_logs").
WithArgs(
log.UserID,
log.APIKeyID,
log.AccountID,
log.RequestID,
log.Model,
sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
log.CacheReadCost,
log.TotalCost,
log.ActualCost,
log.RateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
int16(service.RequestTypeWSV2),
true,
true,
sqlmock.AnyArg(), // duration_ms
sqlmock.AnyArg(), // first_token_ms
sqlmock.AnyArg(), // user_agent
sqlmock.AnyArg(), // ip_address
log.ImageCount,
sqlmock.AnyArg(), // image_size
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // reasoning_effort
log.CacheTTLOverridden,
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
inserted, err := repo.Create(context.Background(), log)
require.NoError(t, err)
require.True(t, inserted)
require.Equal(t, int64(99), log.ID)
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
require.True(t, log.Stream)
require.True(t, log.OpenAIWSMode)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
requestType := int16(service.RequestTypeWSV2)
stream := false
filters := usagestats.UsageLogFilters{
RequestType: &requestType,
Stream: &stream,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(requestType).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3").
WithArgs(requestType, 20, 0).
WillReturnRows(sqlmock.NewRows([]string{"id"}))
logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters)
require.NoError(t, err)
require.Empty(t, logs)
require.NotNil(t, page)
require.Equal(t, int64(0), page.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeStream)
stream := true
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
require.NoError(t, err)
require.Empty(t, trend)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeWSV2)
stream := false
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)
require.Empty(t, stats)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
requestType := int16(service.RequestTypeSync)
stream := true
filters := usagestats.UsageLogFilters{
RequestType: &requestType,
Stream: &stream,
}
mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(requestType).
WillReturnRows(sqlmock.NewRows([]string{
"total_requests",
"total_input_tokens",
"total_output_tokens",
"total_cache_tokens",
"total_cost",
"total_actual_cost",
"total_account_cost",
"avg_duration_ms",
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err)
require.Equal(t, int64(1), stats.TotalRequests)
require.Equal(t, int64(9), stats.TotalTokens)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
tests := []struct {
name string
request int16
wantWhere string
wantArg int16
}{
{
name: "sync_with_legacy_fallback",
request: int16(service.RequestTypeSync),
wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))",
wantArg: int16(service.RequestTypeSync),
},
{
name: "stream_with_legacy_fallback",
request: int16(service.RequestTypeStream),
wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))",
wantArg: int16(service.RequestTypeStream),
},
{
name: "ws_v2_with_legacy_fallback",
request: int16(service.RequestTypeWSV2),
wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))",
wantArg: int16(service.RequestTypeWSV2),
},
{
name: "invalid_request_type_normalized_to_unknown",
request: int16(99),
wantWhere: "request_type = $3",
wantArg: int16(service.RequestTypeUnknown),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
where, args := buildRequestTypeFilterCondition(3, tt.request)
require.Equal(t, tt.wantWhere, where)
require.Equal(t, []any{tt.wantArg}, args)
})
}
}
type usageLogScannerStub struct {
values []any
}
func (s usageLogScannerStub) Scan(dest ...any) error {
if len(dest) != len(s.values) {
return fmt.Errorf("scan arg count mismatch: got %d want %d", len(dest), len(s.values))
}
for i := range dest {
dv := reflect.ValueOf(dest[i])
if dv.Kind() != reflect.Ptr {
return fmt.Errorf("dest[%d] is not pointer", i)
}
dv.Elem().Set(reflect.ValueOf(s.values[i]))
}
return nil
}
func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
int64(1), // id
int64(10), // user_id
int64(20), // api_key_id
int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model
sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id
1, // input_tokens
2, // output_tokens
3, // cache_creation_tokens
4, // cache_read_tokens
5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens
0.1, // input_cost
0.2, // output_cost
0.3, // cache_creation_cost
0.4, // cache_read_cost
1.0, // total_cost
0.9, // actual_cost
1.0, // rate_multiplier
sql.NullFloat64{}, // account_rate_multiplier
int16(service.BillingTypeBalance),
int16(service.RequestTypeWSV2),
false, // legacy stream
false, // legacy openai ws
sql.NullInt64{},
sql.NullInt64{},
sql.NullString{},
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
require.NoError(t, err)
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
require.True(t, log.Stream)
require.True(t, log.OpenAIWSMode)
})
t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
int64(2),
int64(11),
int64(21),
int64(31),
sql.NullString{Valid: true, String: "req-2"},
"gpt-5",
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
int16(service.BillingTypeBalance),
int16(service.RequestTypeUnknown),
true,
false,
sql.NullInt64{},
sql.NullInt64{},
sql.NullString{},
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
require.NoError(t, err)
require.Equal(t, service.RequestTypeStream, log.RequestType)
require.True(t, log.Stream)
require.False(t, log.OpenAIWSMode)
})
}
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
) )
type userGroupRateRepository struct { type userGroupRateRepository struct {
...@@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) ...@@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil return result, nil
} }
// GetByUserIDs 批量获取多个用户的专属分组倍率。
// 返回结构:map[userID]map[groupID]rate
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 {
return result, nil
}
uniqueIDs := make([]int64, 0, len(userIDs))
seen := make(map[int64]struct{}, len(userIDs))
for _, userID := range userIDs {
if userID <= 0 {
continue
}
if _, exists := seen[userID]; exists {
continue
}
seen[userID] = struct{}{}
uniqueIDs = append(uniqueIDs, userID)
result[userID] = make(map[int64]float64)
}
if len(uniqueIDs) == 0 {
return result, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
WHERE user_id = ANY($1)
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var userID int64
var groupID int64
var rate float64
if err := rows.Scan(&userID, &groupID, &rate); err != nil {
return nil, err
}
if _, ok := result[userID]; !ok {
result[userID] = make(map[int64]float64)
}
result[userID][groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率 // GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
...@@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID ...@@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
// 分离需要删除和需要 upsert 的记录 // 分离需要删除和需要 upsert 的记录
var toDelete []int64 var toDelete []int64
toUpsert := make(map[int64]float64) upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates { for groupID, rate := range rates {
if rate == nil { if rate == nil {
toDelete = append(toDelete, groupID) toDelete = append(toDelete, groupID)
} else { } else {
toUpsert[groupID] = *rate upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate)
} }
} }
// 删除指定的记录 // 删除指定的记录
for _, groupID := range toDelete { if len(toDelete) > 0 {
_, err := r.sql.ExecContext(ctx, if _, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
userID, groupID) userID, pq.Array(toDelete)); err != nil {
if err != nil {
return err return err
} }
} }
// Upsert 记录 // Upsert 记录
now := time.Now() now := time.Now()
for groupID, rate := range toUpsert { if len(upsertGroupIDs) > 0 {
_, err := r.sql.ExecContext(ctx, ` _, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4) SELECT
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 $1::bigint,
`, userID, groupID, rate, now) data.group_id,
data.rate_multiplier,
$2::timestamptz,
$2::timestamptz
FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier)
ON CONFLICT (user_id, group_id)
DO UPDATE SET
rate_multiplier = EXCLUDED.rate_multiplier,
updated_at = EXCLUDED.updated_at
`, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates))
if err != nil { if err != nil {
return err return err
} }
......
...@@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
...@@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
...@@ -363,6 +366,65 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount ...@@ -363,6 +366,65 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil return nil
} }
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
WHERE id = $1
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
if err == nil {
return newUsed, nil
}
if errors.Is(err, sql.ErrNoRows) {
// 区分用户不存在和配额冲突
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
if existsErr != nil {
return 0, existsErr
}
if !exists {
return 0, service.ErrUserNotFound
}
return 0, service.ErrSoraStorageQuotaExceeded
}
return 0, err
}
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
WHERE id = $1
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes}, &newUsed)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
return 0, err
}
return newUsed, nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
} }
......
...@@ -186,11 +186,12 @@ func TestAPIContracts(t *testing.T) { ...@@ -186,11 +186,12 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null, "image_price_1k": null,
"image_price_2k": null, "image_price_2k": null,
"image_price_4k": null, "image_price_4k": null,
"sora_image_price_360": null, "sora_image_price_360": null,
"sora_image_price_540": null, "sora_image_price_540": null,
"sora_video_price_per_request": null, "sora_storage_quota_bytes": 0,
"sora_video_price_per_request_hd": null, "sora_video_price_per_request": null,
"claude_code_only": false, "sora_video_price_per_request_hd": null,
"claude_code_only": false,
"fallback_group_id": null, "fallback_group_id": null,
"fallback_group_id_on_invalid_request": null, "fallback_group_id_on_invalid_request": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
...@@ -384,10 +385,12 @@ func TestAPIContracts(t *testing.T) { ...@@ -384,10 +385,12 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1, "user_id": 1,
"api_key_id": 100, "api_key_id": 100,
"account_id": 200, "account_id": 200,
"request_id": "req_123", "request_id": "req_123",
"model": "claude-3", "model": "claude-3",
"group_id": null, "request_type": "stream",
"subscription_id": null, "openai_ws_mode": false,
"group_id": null,
"subscription_id": null,
"input_tokens": 10, "input_tokens": 10,
"output_tokens": 20, "output_tokens": 20,
"cache_creation_tokens": 1, "cache_creation_tokens": 1,
...@@ -500,11 +503,12 @@ func TestAPIContracts(t *testing.T) { ...@@ -500,11 +503,12 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro", "fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro", "fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o", "fallback_model_openai": "gpt-4o",
"enable_identity_patch": true, "enable_identity_patch": true,
"identity_patch_prompt": "", "identity_patch_prompt": "",
"invitation_code_enabled": false, "sora_client_enabled": false,
"home_content": "", "invitation_code_enabled": false,
"home_content": "",
"hide_ccs_import_button": false, "hide_ccs_import_button": false,
"purchase_subscription_enabled": false, "purchase_subscription_enabled": false,
"purchase_subscription_url": "" "purchase_subscription_url": ""
...@@ -619,7 +623,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -619,7 +623,7 @@ func newContractDeps(t *testing.T) *contractDeps {
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
...@@ -1555,11 +1559,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D ...@@ -1555,11 +1559,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment