Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3d79773b
Commit
3d79773b
authored
Mar 04, 2026
by
kyx236
Browse files
Merge branch 'main' of
https://github.com/james-6-23/sub2api
parents
6aa8cbbf
742e73c9
Changes
249
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
249 of 249+
files are displayed.
Plain diff
Email patch
backend/internal/repository/idempotency_repo_integration_test.go
0 → 100644
View file @
3d79773b
//go:build integration
package
repository
import
(
"context"
"crypto/sha256"
"encoding/hex"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns.
func
hashedTestValue
(
t
*
testing
.
T
,
prefix
string
)
string
{
t
.
Helper
()
sum
:=
sha256
.
Sum256
([]
byte
(
uniqueTestValue
(
t
,
prefix
)))
return
hex
.
EncodeToString
(
sum
[
:
])
}
// TestIdempotencyRepo_CreateProcessing_CompeteSameKey verifies that only the
// first CreateProcessing call for a given (scope, key hash) pair wins
// ownership; a second insert with the same pair — even with a different
// request fingerprint — must report owner=false.
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-create"),
		IdempotencyKeyHash: hashedTestValue(t, "idem-hash"),
		RequestFingerprint: hashedTestValue(t, "idem-fp"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(30 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	// First writer becomes the owner and gets a generated ID.
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	require.NotZero(t, record.ID)
	// Same scope + key hash but a different fingerprint: must be treated
	// as a duplicate, not a new record.
	duplicate := &service.IdempotencyRecord{
		Scope:              record.Scope,
		IdempotencyKeyHash: record.IdempotencyKeyHash,
		RequestFingerprint: hashedTestValue(t, "idem-fp-other"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(30 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err = repo.CreateProcessing(ctx, duplicate)
	require.NoError(t, err)
	require.False(t, owner, "same scope+key hash should be de-duplicated")
}
// TestIdempotencyRepo_TryReclaim_StatusAndLockWindow verifies TryReclaim's two
// preconditions: the record must be in the expected status AND its lock must
// have expired. A failed_retryable record with an expired lock is reclaimed
// (back to processing, with a fresh lock); the same record with a live lock
// is not.
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-reclaim"),
		IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"),
		RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(10 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	// Mark the record failed_retryable with a lock that expired 2s ago, so
	// it is eligible for reclaim.
	require.NoError(t, repo.MarkFailedRetryable(
		ctx,
		record.ID,
		"RETRYABLE_FAILURE",
		now.Add(-2*time.Second),
		now.Add(24*time.Hour),
	))
	newLockedUntil := now.Add(20 * time.Second)
	reclaimed, err := repo.TryReclaim(
		ctx,
		record.ID,
		service.IdempotencyStatusFailedRetryable,
		now,
		newLockedUntil,
		now.Add(24*time.Hour),
	)
	require.NoError(t, err)
	require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
	// After reclaim the record is processing again with a future lock.
	got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
	require.NotNil(t, got.LockedUntil)
	require.True(t, got.LockedUntil.After(now))
	// Re-fail it, but this time with a lock 20s in the future: the lock is
	// still live, so the next reclaim attempt must be refused.
	require.NoError(t, repo.MarkFailedRetryable(
		ctx,
		record.ID,
		"RETRYABLE_FAILURE",
		now.Add(20*time.Second),
		now.Add(24*time.Hour),
	))
	reclaimed, err = repo.TryReclaim(
		ctx,
		record.ID,
		service.IdempotencyStatusFailedRetryable,
		now,
		now.Add(40*time.Second),
		now.Add(24*time.Hour),
	)
	require.NoError(t, err)
	require.False(t, reclaimed, "within lock window should not reclaim")
}
// TestIdempotencyRepo_StatusTransition_ToSucceeded verifies that MarkSucceeded
// persists the cached response (status code + body), flips the record to the
// succeeded status, and clears the processing lock.
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
	tx := testTx(t)
	repo := &idempotencyRepository{sql: tx}
	ctx := context.Background()
	now := time.Now().UTC()
	record := &service.IdempotencyRecord{
		Scope:              uniqueTestValue(t, "idem-scope-success"),
		IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"),
		RequestFingerprint: hashedTestValue(t, "idem-fp-success"),
		Status:             service.IdempotencyStatusProcessing,
		LockedUntil:        ptrTime(now.Add(10 * time.Second)),
		ExpiresAt:          now.Add(24 * time.Hour),
	}
	owner, err := repo.CreateProcessing(ctx, record)
	require.NoError(t, err)
	require.True(t, owner)
	// Record the successful response against the record.
	require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
	got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
	// Cached response status and body round-trip intact.
	require.NotNil(t, got.ResponseStatus)
	require.Equal(t, 200, *got.ResponseStatus)
	require.NotNil(t, got.ResponseBody)
	require.Equal(t, `{"ok":true}`, *got.ResponseBody)
	// Success releases the processing lock.
	require.Nil(t, got.LockedUntil)
}
backend/internal/repository/identity_cache.go
View file @
3d79773b
...
...
@@ -12,7 +12,7 @@ import (
const
(
fingerprintKeyPrefix
=
"fingerprint:"
fingerprintTTL
=
24
*
time
.
Hour
fingerprintTTL
=
7
*
24
*
time
.
Hour
// 7天,配合每24小时懒续期可保持活跃账号永不过期
maskedSessionKeyPrefix
=
"masked_session:"
maskedSessionTTL
=
15
*
time
.
Minute
)
...
...
backend/internal/repository/migrations_runner.go
View file @
3d79773b
...
...
@@ -50,6 +50,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
// Any stable int64 value works, as long as it does not collide with other
// advisory locks in the same database.
const migrationsAdvisoryLockID int64 = 694208311321144027

// migrationsLockRetryInterval is how long to wait between attempts to grab
// the advisory lock when another process holds it.
const migrationsLockRetryInterval = 500 * time.Millisecond

// nonTransactionalMigrationSuffix marks migration files that must run outside
// a transaction (e.g. CREATE/DROP INDEX CONCURRENTLY).
const nonTransactionalMigrationSuffix = "_notx.sql"
// migrationChecksumCompatibilityRule is one whitelist entry that lets a
// historically mis-edited migration pass the checksum check: the current file
// checksum must match fileChecksum AND the checksum stored in the database
// must be one of acceptedDBChecksum.
type migrationChecksumCompatibilityRule struct {
	fileChecksum       string
	acceptedDBChecksum map[string]struct{}
}
// migrationChecksumCompatibilityRules exists only to tolerate migration files
// whose checksums were historically modified by mistake. A rule is applied
// only when migration name + current file checksum + historical DB checksum
// all match, so the global immutability check is never loosened in general.
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
	"054_drop_legacy_cache_columns.sql": {
		fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
		acceptedDBChecksum: map[string]struct{}{
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
		},
	},
	"061_add_usage_log_request_type.sql": {
		fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
		acceptedDBChecksum: map[string]struct{}{
			"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
			"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
		},
	},
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
...
...
@@ -147,6 +171,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if
rowErr
==
nil
{
// 迁移已应用,验证校验和是否匹配
if
existing
!=
checksum
{
// 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
if
isMigrationChecksumCompatible
(
name
,
existing
,
checksum
)
{
continue
}
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return
fmt
.
Errorf
(
...
...
@@ -165,8 +193,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return
fmt
.
Errorf
(
"check migration %s: %w"
,
name
,
rowErr
)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
nonTx
,
err
:=
validateMigrationExecutionMode
(
name
,
content
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"validate migration %s: %w"
,
name
,
err
)
}
if
nonTx
{
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements
:=
splitSQLStatements
(
content
)
for
i
,
stmt
:=
range
statements
{
trimmed
:=
strings
.
TrimSpace
(
stmt
)
if
trimmed
==
""
{
continue
}
if
stripSQLLineComment
(
trimmed
)
==
""
{
continue
}
if
_
,
err
:=
db
.
ExecContext
(
ctx
,
trimmed
);
err
!=
nil
{
return
fmt
.
Errorf
(
"apply migration %s (non-tx statement %d): %w"
,
name
,
i
+
1
,
err
)
}
}
if
_
,
err
:=
db
.
ExecContext
(
ctx
,
"INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)"
,
name
,
checksum
);
err
!=
nil
{
return
fmt
.
Errorf
(
"record migration %s (non-tx): %w"
,
name
,
err
)
}
continue
}
// 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
tx
,
err
:=
db
.
BeginTx
(
ctx
,
nil
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"begin migration %s: %w"
,
name
,
err
)
...
...
@@ -268,6 +322,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return
version
,
version
,
hash
,
nil
}
func
isMigrationChecksumCompatible
(
name
,
dbChecksum
,
fileChecksum
string
)
bool
{
rule
,
ok
:=
migrationChecksumCompatibilityRules
[
name
]
if
!
ok
{
return
false
}
if
rule
.
fileChecksum
!=
fileChecksum
{
return
false
}
_
,
ok
=
rule
.
acceptedDBChecksum
[
dbChecksum
]
return
ok
}
func
validateMigrationExecutionMode
(
name
,
content
string
)
(
bool
,
error
)
{
normalizedName
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
name
))
upperContent
:=
strings
.
ToUpper
(
content
)
nonTx
:=
strings
.
HasSuffix
(
normalizedName
,
nonTransactionalMigrationSuffix
)
if
!
nonTx
{
if
strings
.
Contains
(
upperContent
,
"CONCURRENTLY"
)
{
return
false
,
errors
.
New
(
"CONCURRENTLY statements must be placed in *_notx.sql migrations"
)
}
return
false
,
nil
}
if
strings
.
Contains
(
upperContent
,
"BEGIN"
)
||
strings
.
Contains
(
upperContent
,
"COMMIT"
)
||
strings
.
Contains
(
upperContent
,
"ROLLBACK"
)
{
return
false
,
errors
.
New
(
"*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)"
)
}
statements
:=
splitSQLStatements
(
content
)
for
_
,
stmt
:=
range
statements
{
normalizedStmt
:=
strings
.
ToUpper
(
stripSQLLineComment
(
strings
.
TrimSpace
(
stmt
)))
if
normalizedStmt
==
""
{
continue
}
if
strings
.
Contains
(
normalizedStmt
,
"CONCURRENTLY"
)
{
isCreateIndex
:=
strings
.
Contains
(
normalizedStmt
,
"CREATE"
)
&&
strings
.
Contains
(
normalizedStmt
,
"INDEX"
)
isDropIndex
:=
strings
.
Contains
(
normalizedStmt
,
"DROP"
)
&&
strings
.
Contains
(
normalizedStmt
,
"INDEX"
)
if
!
isCreateIndex
&&
!
isDropIndex
{
return
false
,
errors
.
New
(
"*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements"
)
}
if
isCreateIndex
&&
!
strings
.
Contains
(
normalizedStmt
,
"IF NOT EXISTS"
)
{
return
false
,
errors
.
New
(
"CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency"
)
}
if
isDropIndex
&&
!
strings
.
Contains
(
normalizedStmt
,
"IF EXISTS"
)
{
return
false
,
errors
.
New
(
"DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency"
)
}
continue
}
return
false
,
errors
.
New
(
"*_notx.sql must not mix non-CONCURRENTLY SQL statements"
)
}
return
true
,
nil
}
// splitSQLStatements naively splits content on ';' and drops segments that
// are blank after trimming. Segments themselves are returned untrimmed.
// Note: this does not understand string literals containing ';' — adequate
// for the restricted SQL allowed in *_notx.sql migrations.
func splitSQLStatements(content string) []string {
	segments := strings.Split(content, ";")
	statements := make([]string, 0, len(segments))
	for _, segment := range segments {
		if strings.TrimSpace(segment) == "" {
			continue
		}
		statements = append(statements, segment)
	}
	return statements
}
// stripSQLLineComment removes `--` line comments from every line of s and
// trims surrounding whitespace from the result. It does not parse string
// literals, so a `--` inside a quoted string would also be stripped — fine
// for the restricted statements this helper is used on.
func stripSQLLineComment(s string) string {
	rawLines := strings.Split(s, "\n")
	cleaned := make([]string, len(rawLines))
	for i, line := range rawLines {
		code, _, _ := strings.Cut(line, "--")
		cleaned[i] = code
	}
	return strings.TrimSpace(strings.Join(cleaned, "\n"))
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
...
...
backend/internal/repository/migrations_runner_checksum_test.go
0 → 100644
View file @
3d79773b
package
repository
import
(
"testing"
"github.com/stretchr/testify/require"
)
// TestIsMigrationChecksumCompatible pins the concrete whitelist entries:
// known historical (db, file) checksum pairs pass, while unknown checksums
// and non-whitelisted migrations are rejected.
func TestIsMigrationChecksumCompatible(t *testing.T) {
	t.Run("054历史checksum可兼容", func(t *testing.T) {
		ok := isMigrationChecksumCompatible(
			"054_drop_legacy_cache_columns.sql",
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
			"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
		)
		require.True(t, ok)
	})
	t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
		ok := isMigrationChecksumCompatible(
			"054_drop_legacy_cache_columns.sql",
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
			"0000000000000000000000000000000000000000000000000000000000000000",
		)
		require.False(t, ok)
	})
	t.Run("061历史checksum可兼容", func(t *testing.T) {
		ok := isMigrationChecksumCompatible(
			"061_add_usage_log_request_type.sql",
			"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0",
			"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
		)
		require.True(t, ok)
	})
	t.Run("061第二个历史checksum可兼容", func(t *testing.T) {
		ok := isMigrationChecksumCompatible(
			"061_add_usage_log_request_type.sql",
			"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3",
			"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
		)
		require.True(t, ok)
	})
	t.Run("非白名单迁移不兼容", func(t *testing.T) {
		ok := isMigrationChecksumCompatible(
			"001_init.sql",
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
			"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
		)
		require.False(t, ok)
	})
}
backend/internal/repository/migrations_runner_extra_test.go
0 → 100644
View file @
3d79773b
package
repository
import
(
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io/fs"
"strings"
"testing"
"testing/fstest"
"time"
sqlmock
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
// TestApplyMigrations_NilDB verifies ApplyMigrations rejects a nil *sql.DB
// with a descriptive error instead of panicking.
func TestApplyMigrations_NilDB(t *testing.T) {
	err := ApplyMigrations(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "nil sql db")
}
// TestApplyMigrations_DelegatesToApplyMigrationsFS verifies that
// ApplyMigrations reaches the shared FS-based runner: a failure acquiring the
// advisory lock (the runner's first step) surfaces through ApplyMigrations.
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnError(errors.New("lock failed"))
	err = ApplyMigrations(context.Background(), db)
	require.Error(t, err)
	require.Contains(t, err.Error(), "acquire migrations lock")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestLatestMigrationBaseline covers the three behaviors of
// latestMigrationBaseline: an empty FS yields the literal "baseline" with an
// empty hash, the lexicographically last *.sql file is picked as the
// baseline, and an unreadable entry surfaces a read error.
func TestLatestMigrationBaseline(t *testing.T) {
	t.Run("empty_fs_returns_baseline", func(t *testing.T) {
		version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
		require.NoError(t, err)
		require.Equal(t, "baseline", version)
		require.Equal(t, "baseline", description)
		require.Equal(t, "", hash)
	})
	t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
			"010_final.sql": &fstest.MapFile{
				Data: []byte("CREATE TABLE t2(id int);"),
			},
		}
		version, description, hash, err := latestMigrationBaseline(fsys)
		require.NoError(t, err)
		require.Equal(t, "010_final", version)
		require.Equal(t, "010_final", description)
		// SHA-256 hex digest length.
		require.Len(t, hash, 64)
	})
	t.Run("read_file_error", func(t *testing.T) {
		// A directory-mode entry cannot be read as a file.
		fsys := fstest.MapFS{
			"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
		}
		_, _, _, err := latestMigrationBaseline(fsys)
		require.Error(t, err)
	})
}
// TestIsMigrationChecksumCompatible_AdditionalCases exercises the whitelist
// logic without hard-coding checksums: it picks an arbitrary rule from the
// map and checks each rejection/acceptance branch against it.
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
	require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
	var (
		name string
		rule migrationChecksumCompatibilityRule
	)
	// Grab any one rule (map iteration order is irrelevant here).
	for n, r := range migrationChecksumCompatibilityRules {
		name = n
		rule = r
		break
	}
	require.NotEmpty(t, name)
	// Wrong file checksum: rejected regardless of the db checksum.
	require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
	// Right file checksum but db checksum not in the accepted set: rejected.
	require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
	var accepted string
	for checksum := range rule.acceptedDBChecksum {
		accepted = checksum
		break
	}
	require.NotEmpty(t, accepted)
	// Both checks satisfied: accepted.
	require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
// TestEnsureAtlasBaselineAligned covers ensureAtlasBaselineAligned against a
// mocked database: the happy paths (no legacy table → no-op; legacy table
// present + empty atlas table → create table and insert a baseline row) and
// each error branch (legacy-table check, atlas row count, atlas table
// creation, baseline insert).
func TestEnsureAtlasBaselineAligned(t *testing.T) {
	t.Run("skip_when_no_legacy_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// No legacy schema_migrations table → nothing to align.
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
			WillReturnResult(sqlmock.NewResult(0, 0))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
		// The baseline row uses the latest migration (002_next).
		mock.ExpectExec("INSERT INTO atlas_schema_revisions").
			WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
			WillReturnResult(sqlmock.NewResult(1, 1))
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
			"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
		}
		err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_checking_legacy_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("schema_migrations").
			WillReturnError(errors.New("exists failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "check schema_migrations")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnError(errors.New("count failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "count atlas_schema_revisions")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_creating_atlas_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
			WillReturnError(errors.New("create failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "create atlas_schema_revisions")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_inserting_baseline", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS\\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
		mock.ExpectExec("INSERT INTO atlas_schema_revisions").
			WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
			WillReturnError(errors.New("insert failed"))
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
		}
		err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
		require.Error(t, err)
		require.Contains(t, err.Error(), "insert atlas baseline")
		require.NoError(t, mock.ExpectationsWereMet())
	})
}
// TestApplyMigrationsFS_ChecksumMismatchRejected verifies the immutability
// guarantee: an already-applied migration whose stored checksum no longer
// matches the file (and is not whitelisted) aborts the run, and the advisory
// lock is still released.
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_init.sql").
		WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
	// The lock must be released even on failure.
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "checksum mismatch")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_CheckMigrationQueryError verifies that a database
// error while looking up an existing checksum is wrapped with the migration
// name, and that the advisory lock is released afterwards.
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_err.sql").
		WillReturnError(errors.New("query failed"))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "check migration 001_err.sql")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied verifies two skip paths:
// a whitespace-only migration file never touches the database, and a
// migration whose stored checksum matches the file is not re-applied.
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	alreadySQL := "CREATE TABLE t(id int);"
	checksum := migrationChecksum(alreadySQL)
	// Only the non-empty file is checked; its checksum matches → skipped.
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_already.sql").
		WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"000_empty.sql":   &fstest.MapFile{Data: []byte(" \n\t ")},
		"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_ReadMigrationError verifies that an unreadable
// migration entry (directory mode) fails the run with the file name in the
// error, and the advisory lock is released.
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "read migration 001_bad.sql")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestPgAdvisoryLockAndUnlock_ErrorBranches covers the advisory-lock helpers:
// giving up when the context expires while the lock is held elsewhere,
// surfacing an unlock error, and successfully acquiring after one retry
// (which must take at least one retry interval).
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
	t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// Lock attempt fails (false) and the short context then expires.
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
		defer cancel()
		err = pgAdvisoryLock(ctx, db)
		require.Error(t, err)
		require.Contains(t, err.Error(), "acquire migrations lock")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("unlock_exec_error", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnError(errors.New("unlock failed"))
		err = pgAdvisoryUnlock(context.Background(), db)
		require.Error(t, err)
		require.Contains(t, err.Error(), "release migrations lock")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("acquire_lock_after_retry", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// First attempt misses, second succeeds.
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
		ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
		defer cancel()
		start := time.Now()
		err = pgAdvisoryLock(ctx, db)
		require.NoError(t, err)
		// At least one retry interval must have elapsed before success.
		require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
		require.NoError(t, mock.ExpectationsWereMet())
	})
}
// migrationChecksum computes the SHA-256 hex digest of the migration content
// with surrounding whitespace trimmed, mirroring how the runner records
// checksums in schema_migrations.
func migrationChecksum(content string) string {
	trimmed := strings.TrimSpace(content)
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:])
}
backend/internal/repository/migrations_runner_notx_test.go
0 → 100644
View file @
3d79773b
package
repository
import
(
"context"
"database/sql"
"testing"
"testing/fstest"
sqlmock
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
// TestValidateMigrationExecutionMode covers the classification rules for
// transactional vs *_notx.sql migrations: CONCURRENTLY is rejected inside
// transactional files; *_notx.sql requires idempotent index statements only,
// with no transaction control and no mixed-in regular SQL.
func TestValidateMigrationExecutionMode(t *testing.T) {
	t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移禁止事务控制语句", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;")
		require.False(t, nonTx)
		require.Error(t, err)
	})
	t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) {
		nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", `
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
DROP INDEX CONCURRENTLY IF EXISTS idx_b;
`)
		require.True(t, nonTx)
		require.NoError(t, err)
	})
}
// TestApplyMigrationsFS_NonTransactionalMigration verifies that a *_notx.sql
// migration is executed directly on the DB (no BEGIN/COMMIT expected) and
// then recorded in schema_migrations.
func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	// Not yet applied.
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_idx_notx.sql").
		WillReturnError(sql.ErrNoRows)
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_idx_notx.sql": &fstest.MapFile{
			Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements verifies
// that a *_notx.sql file containing multiple CONCURRENTLY statements (with
// line comments) is executed one statement at a time, then recorded once.
func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_multi_idx_notx.sql").
		WillReturnError(sql.ErrNoRows)
	// Each statement is executed individually, in file order.
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_multi_idx_notx.sql": &fstest.MapFile{
			Data: []byte(`
-- first
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
-- second
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
`),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_TransactionalMigration verifies the default path: a
// regular migration runs inside BEGIN/COMMIT, with the schema_migrations
// bookkeeping insert in the same transaction.
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_col.sql").
		WillReturnError(sql.ErrNoRows)
	mock.ExpectBegin()
	mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_col.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_col.sql": &fstest.MapFile{
			Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// prepareMigrationsBootstrapExpectations registers the sqlmock expectations
// shared by every applyMigrationsFS test: advisory-lock acquisition, creation
// of schema_migrations, and the atlas baseline alignment queries (both tables
// exist, atlas already has a revision row so no baseline insert happens).
func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
	mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
	mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectQuery("SELECT EXISTS\\(").
		WithArgs("schema_migrations").
		WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
	mock.ExpectQuery("SELECT EXISTS\\(").
		WithArgs("atlas_schema_revisions").
		WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
}
backend/internal/repository/migrations_schema_integration_test.go
View file @
3d79773b
...
...
@@ -42,12 +42,19 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// usage_logs: billing_type used by filters/stats
requireColumn
(
t
,
tx
,
"usage_logs"
,
"billing_type"
,
"smallint"
,
0
,
false
)
requireColumn
(
t
,
tx
,
"usage_logs"
,
"request_type"
,
"smallint"
,
0
,
false
)
requireColumn
(
t
,
tx
,
"usage_logs"
,
"openai_ws_mode"
,
"boolean"
,
0
,
false
)
// settings table should exist
var
settingsRegclass
sql
.
NullString
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
context
.
Background
(),
"SELECT to_regclass('public.settings')"
)
.
Scan
(
&
settingsRegclass
))
require
.
True
(
t
,
settingsRegclass
.
Valid
,
"expected settings table to exist"
)
// security_secrets table should exist
var
securitySecretsRegclass
sql
.
NullString
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
context
.
Background
(),
"SELECT to_regclass('public.security_secrets')"
)
.
Scan
(
&
securitySecretsRegclass
))
require
.
True
(
t
,
securitySecretsRegclass
.
Valid
,
"expected security_secrets table to exist"
)
// user_allowed_groups table should exist
var
uagRegclass
sql
.
NullString
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
context
.
Background
(),
"SELECT to_regclass('public.user_allowed_groups')"
)
.
Scan
(
&
uagRegclass
))
...
...
backend/internal/repository/openai_oauth_service.go
View file @
3d79773b
...
...
@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
"strings"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...
...
@@ -21,16 +22,23 @@ type openaiOAuthService struct {
tokenURL
string
}
func
(
s
*
openaiOAuthService
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
,
redirectURI
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
client
:=
createOpenAIReqClient
(
proxyURL
)
func
(
s
*
openaiOAuthService
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
,
redirectURI
,
proxyURL
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
client
,
err
:=
createOpenAIReqClient
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"OPENAI_OAUTH_CLIENT_INIT_FAILED"
,
"create HTTP client: %v"
,
err
)
}
if
redirectURI
==
""
{
redirectURI
=
openai
.
DefaultRedirectURI
}
clientID
=
strings
.
TrimSpace
(
clientID
)
if
clientID
==
""
{
clientID
=
openai
.
ClientID
}
formData
:=
url
.
Values
{}
formData
.
Set
(
"grant_type"
,
"authorization_code"
)
formData
.
Set
(
"client_id"
,
openai
.
C
lientID
)
formData
.
Set
(
"client_id"
,
c
lientID
)
formData
.
Set
(
"code"
,
code
)
formData
.
Set
(
"redirect_uri"
,
redirectURI
)
formData
.
Set
(
"code_verifier"
,
codeVerifier
)
...
...
@@ -56,12 +64,28 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func
(
s
*
openaiOAuthService
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
client
:=
createOpenAIReqClient
(
proxyURL
)
return
s
.
RefreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
""
)
}
func
(
s
*
openaiOAuthService
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
// 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID
clientID
=
strings
.
TrimSpace
(
clientID
)
if
clientID
==
""
{
clientID
=
openai
.
ClientID
}
return
s
.
refreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
clientID
)
}
func
(
s
*
openaiOAuthService
)
refreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
client
,
err
:=
createOpenAIReqClient
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"OPENAI_OAUTH_CLIENT_INIT_FAILED"
,
"create HTTP client: %v"
,
err
)
}
formData
:=
url
.
Values
{}
formData
.
Set
(
"grant_type"
,
"refresh_token"
)
formData
.
Set
(
"refresh_token"
,
refreshToken
)
formData
.
Set
(
"client_id"
,
openai
.
C
lientID
)
formData
.
Set
(
"client_id"
,
c
lientID
)
formData
.
Set
(
"scope"
,
openai
.
RefreshScopes
)
var
tokenResp
openai
.
TokenResponse
...
...
@@ -84,7 +108,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return
&
tokenResp
,
nil
}
func
createOpenAIReqClient
(
proxyURL
string
)
*
req
.
Client
{
func
createOpenAIReqClient
(
proxyURL
string
)
(
*
req
.
Client
,
error
)
{
return
getSharedReqClient
(
reqClientOptions
{
ProxyURL
:
proxyURL
,
Timeout
:
120
*
time
.
Second
,
...
...
backend/internal/repository/openai_oauth_service_test.go
View file @
3d79773b
...
...
@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`
)
}))
resp
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
""
,
""
)
resp
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
""
,
""
,
""
)
require
.
NoError
(
s
.
T
(),
err
,
"ExchangeCode"
)
select
{
case
msg
:=
<-
errCh
:
...
...
@@ -136,13 +136,84 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require
.
Equal
(
s
.
T
(),
"rt2"
,
resp
.
RefreshToken
)
}
// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID,
// 且只发送一次请求(不再盲猜多个 client_id)。
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_DefaultsToOpenAIClientID
()
{
var
seenClientIDs
[]
string
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
err
:=
r
.
ParseForm
();
err
!=
nil
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
clientID
:=
r
.
PostForm
.
Get
(
"client_id"
)
seenClientIDs
=
append
(
seenClientIDs
,
clientID
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`
)
}))
resp
,
err
:=
s
.
svc
.
RefreshToken
(
s
.
ctx
,
"rt"
,
""
)
require
.
NoError
(
s
.
T
(),
err
,
"RefreshToken"
)
require
.
Equal
(
s
.
T
(),
"at"
,
resp
.
AccessToken
)
// 只发送了一次请求,使用默认的 OpenAI ClientID
require
.
Equal
(
s
.
T
(),
[]
string
{
openai
.
ClientID
},
seenClientIDs
)
}
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_UseSoraClientID
()
{
var
seenClientIDs
[]
string
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
err
:=
r
.
ParseForm
();
err
!=
nil
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
clientID
:=
r
.
PostForm
.
Get
(
"client_id"
)
seenClientIDs
=
append
(
seenClientIDs
,
clientID
)
if
clientID
==
openai
.
SoraClientID
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`
)
return
}
w
.
WriteHeader
(
http
.
StatusBadRequest
)
}))
resp
,
err
:=
s
.
svc
.
RefreshTokenWithClientID
(
s
.
ctx
,
"rt"
,
""
,
openai
.
SoraClientID
)
require
.
NoError
(
s
.
T
(),
err
,
"RefreshTokenWithClientID"
)
require
.
Equal
(
s
.
T
(),
"at-sora"
,
resp
.
AccessToken
)
require
.
Equal
(
s
.
T
(),
[]
string
{
openai
.
SoraClientID
},
seenClientIDs
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_UseProvidedClientID
()
{
const
customClientID
=
"custom-client-id"
var
seenClientIDs
[]
string
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
err
:=
r
.
ParseForm
();
err
!=
nil
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
clientID
:=
r
.
PostForm
.
Get
(
"client_id"
)
seenClientIDs
=
append
(
seenClientIDs
,
clientID
)
if
clientID
!=
customClientID
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`
)
}))
resp
,
err
:=
s
.
svc
.
RefreshTokenWithClientID
(
s
.
ctx
,
"rt"
,
""
,
customClientID
)
require
.
NoError
(
s
.
T
(),
err
,
"RefreshTokenWithClientID"
)
require
.
Equal
(
s
.
T
(),
"at-custom"
,
resp
.
AccessToken
)
require
.
Equal
(
s
.
T
(),
"rt-custom"
,
resp
.
RefreshToken
)
require
.
Equal
(
s
.
T
(),
[]
string
{
customClientID
},
seenClientIDs
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestNonSuccessStatus_IncludesBody
()
{
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
_
,
_
=
io
.
WriteString
(
w
,
"bad"
)
}))
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
)
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
,
""
)
require
.
Error
(
s
.
T
(),
err
)
require
.
ErrorContains
(
s
.
T
(),
err
,
"status 400"
)
require
.
ErrorContains
(
s
.
T
(),
err
,
"bad"
)
...
...
@@ -152,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{}))
s
.
srv
.
Close
()
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
)
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
,
""
)
require
.
Error
(
s
.
T
(),
err
)
require
.
ErrorContains
(
s
.
T
(),
err
,
"request failed"
)
}
...
...
@@ -169,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
done
:=
make
(
chan
error
,
1
)
go
func
()
{
_
,
err
:=
s
.
svc
.
ExchangeCode
(
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
)
_
,
err
:=
s
.
svc
.
ExchangeCode
(
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
,
""
)
done
<-
err
}()
...
...
@@ -195,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at","token_type":"bearer","expires_in":1}`
)
}))
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
want
,
""
)
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
want
,
""
,
""
)
require
.
NoError
(
s
.
T
(),
err
,
"ExchangeCode"
)
select
{
case
msg
:=
<-
errCh
:
require
.
Fail
(
s
.
T
(),
msg
)
default
:
}
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestExchangeCode_UseProvidedClientID
()
{
wantClientID
:=
openai
.
SoraClientID
errCh
:=
make
(
chan
string
,
1
)
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
=
r
.
ParseForm
()
if
got
:=
r
.
PostForm
.
Get
(
"client_id"
);
got
!=
wantClientID
{
errCh
<-
"client_id mismatch"
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at","token_type":"bearer","expires_in":1}`
)
}))
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
,
wantClientID
)
require
.
NoError
(
s
.
T
(),
err
,
"ExchangeCode"
)
select
{
case
msg
:=
<-
errCh
:
...
...
@@ -213,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
}))
s
.
svc
.
tokenURL
=
s
.
srv
.
URL
+
"?x=1"
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
)
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
,
""
)
require
.
NoError
(
s
.
T
(),
err
,
"ExchangeCode"
)
select
{
case
<-
s
.
received
:
...
...
@@ -229,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
_
,
_
=
io
.
WriteString
(
w
,
"not-valid-json"
)
}))
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
)
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
,
""
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for invalid JSON response"
)
}
...
...
Prev
1
…
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment