Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
25a304c2
Commit
25a304c2
authored
Dec 25, 2025
by
Forest
Browse files
test: 增加 repository 测试
parent
9d30ceae
Changes
36
Expand all
Hide whitespace changes
Inline
Side-by-side
.github/workflows/backend-ci.yml
View file @
25a304c2
...
...
@@ -17,9 +17,12 @@ jobs:
go-version-file
:
backend/go.mod
check-latest
:
true
cache
:
true
-
name
:
Run
tests
-
name
:
Unit
tests
working-directory
:
backend
run
:
go test ./...
run
:
make test-unit
-
name
:
Integration tests
working-directory
:
backend
run
:
make test-integration
golangci-lint
:
runs-on
:
ubuntu-latest
...
...
backend/Makefile
View file @
25a304c2
.PHONY
:
wire build build-embed
.PHONY
:
wire build build-embed
test-unit test-integration test-cover-integration clean-coverage
wire
:
@
echo
"生成 Wire 代码..."
...
...
@@ -13,4 +13,21 @@ build:
build-embed
:
@
echo
"构建后端(嵌入前端)..."
@
go build
-tags
embed
-o
bin/server ./cmd/server
@
echo
"构建完成: bin/server (with embedded frontend)"
\ No newline at end of file
@
echo
"构建完成: bin/server (with embedded frontend)"
test-unit
:
@
go
test
./...
$(TEST_ARGS)
test-integration
:
@
go
test
-tags
integration ./internal/repository
-count
=
1
-race
-parallel
=
8
test-cover-integration
:
@
echo
"运行集成测试并生成覆盖率报告..."
@
go
test
-tags
=
integration
-cover
-coverprofile
=
coverage.out
-count
=
1
-race
-parallel
=
8 ./internal/repository/...
@
go tool cover
-func
=
coverage.out |
tail
-1
@
go tool cover
-html
=
coverage.out
-o
coverage.html
@
echo
"覆盖率报告已生成: coverage.html"
clean-coverage
:
@
rm
-f
coverage.out coverage.html
@
echo
"覆盖率文件已清理"
\ No newline at end of file
backend/go.mod
View file @
25a304c2
...
...
@@ -11,8 +11,11 @@ require (
github.com/google/wire v0.7.0
github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.
3.0
github.com/redis/go-redis/v9 v9.
7.3
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/crypto v0.44.0
...
...
@@ -24,52 +27,99 @@ require (
)
require (
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
...
...
@@ -79,6 +129,7 @@ require (
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)
backend/go.sum
View file @
25a304c2
This diff is collapsed.
Click to expand it.
backend/internal/repository/account_repo_integration_test.go
0 → 100644
View file @
25a304c2
This diff is collapsed.
Click to expand it.
backend/internal/repository/api_key_cache_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
ApiKeyCacheSuite
struct
{
IntegrationRedisSuite
}
func
(
s
*
ApiKeyCacheSuite
)
TestCreateAttemptCount
()
{
tests
:=
[]
struct
{
name
string
fn
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
*
apiKeyCache
)
}{
{
name
:
"missing_key_returns_redis_nil"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
*
apiKeyCache
)
{
userID
:=
int64
(
1
)
_
,
err
:=
cache
.
GetCreateAttemptCount
(
ctx
,
userID
)
require
.
ErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected redis.Nil for missing key"
)
},
},
{
name
:
"increment_increases_count_and_sets_ttl"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
*
apiKeyCache
)
{
userID
:=
int64
(
1
)
key
:=
fmt
.
Sprintf
(
"%s%d"
,
apiKeyRateLimitKeyPrefix
,
userID
)
require
.
NoError
(
s
.
T
(),
cache
.
IncrementCreateAttemptCount
(
ctx
,
userID
),
"IncrementCreateAttemptCount"
)
require
.
NoError
(
s
.
T
(),
cache
.
IncrementCreateAttemptCount
(
ctx
,
userID
),
"IncrementCreateAttemptCount 2"
)
count
,
err
:=
cache
.
GetCreateAttemptCount
(
ctx
,
userID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetCreateAttemptCount"
)
require
.
Equal
(
s
.
T
(),
2
,
count
,
"count mismatch"
)
ttl
,
err
:=
rdb
.
TTL
(
ctx
,
key
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
apiKeyRateLimitDuration
)
},
},
{
name
:
"delete_removes_key"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
*
apiKeyCache
)
{
userID
:=
int64
(
1
)
require
.
NoError
(
s
.
T
(),
cache
.
IncrementCreateAttemptCount
(
ctx
,
userID
))
require
.
NoError
(
s
.
T
(),
cache
.
DeleteCreateAttemptCount
(
ctx
,
userID
),
"DeleteCreateAttemptCount"
)
_
,
err
:=
cache
.
GetCreateAttemptCount
(
ctx
,
userID
)
require
.
ErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected redis.Nil after delete"
)
},
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
// 每个 case 重新获取隔离资源
rdb
:=
testRedis
(
s
.
T
())
cache
:=
&
apiKeyCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
tt
.
fn
(
ctx
,
rdb
,
cache
)
})
}
}
func
(
s
*
ApiKeyCacheSuite
)
TestDailyUsage
()
{
tests
:=
[]
struct
{
name
string
fn
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
*
apiKeyCache
)
}{
{
name
:
"increment_increases_count"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
*
apiKeyCache
)
{
dailyKey
:=
"daily:sk-test"
require
.
NoError
(
s
.
T
(),
cache
.
IncrementDailyUsage
(
ctx
,
dailyKey
),
"IncrementDailyUsage"
)
require
.
NoError
(
s
.
T
(),
cache
.
IncrementDailyUsage
(
ctx
,
dailyKey
),
"IncrementDailyUsage 2"
)
n
,
err
:=
rdb
.
Get
(
ctx
,
dailyKey
)
.
Int
()
require
.
NoError
(
s
.
T
(),
err
,
"Get dailyKey"
)
require
.
Equal
(
s
.
T
(),
2
,
n
,
"expected daily usage=2"
)
},
},
{
name
:
"set_expiry_sets_ttl"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
*
apiKeyCache
)
{
dailyKey
:=
"daily:sk-test-expiry"
require
.
NoError
(
s
.
T
(),
cache
.
IncrementDailyUsage
(
ctx
,
dailyKey
))
require
.
NoError
(
s
.
T
(),
cache
.
SetDailyUsageExpiry
(
ctx
,
dailyKey
,
1
*
time
.
Hour
),
"SetDailyUsageExpiry"
)
ttl
,
err
:=
rdb
.
TTL
(
ctx
,
dailyKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL dailyKey"
)
require
.
Greater
(
s
.
T
(),
ttl
,
time
.
Duration
(
0
),
"expected ttl > 0"
)
},
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
rdb
:=
testRedis
(
s
.
T
())
cache
:=
&
apiKeyCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
tt
.
fn
(
ctx
,
rdb
,
cache
)
})
}
}
func
TestApiKeyCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ApiKeyCacheSuite
))
}
backend/internal/repository/api_key_repo_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type
ApiKeyRepoSuite
struct
{
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
ApiKeyRepository
}
func
(
s
*
ApiKeyRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewApiKeyRepository
(
s
.
db
)
}
func
TestApiKeyRepoSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ApiKeyRepoSuite
))
}
// --- Create / GetByID / GetByKey ---
func
(
s
*
ApiKeyRepoSuite
)
TestCreate
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"create@test.com"
})
key
:=
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-create-test"
,
Name
:
"Test Key"
,
Status
:
model
.
StatusActive
,
}
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
key
)
s
.
Require
()
.
NoError
(
err
,
"Create"
)
s
.
Require
()
.
NotZero
(
key
.
ID
,
"expected ID to be set"
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
key
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetByID"
)
s
.
Require
()
.
Equal
(
"sk-create-test"
,
got
.
Key
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestGetByID_NotFound
()
{
_
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
999999
)
s
.
Require
()
.
Error
(
err
,
"expected error for non-existent ID"
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestGetByKey
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"getbykey@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-key"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-getbykey"
,
Name
:
"My Key"
,
GroupID
:
&
group
.
ID
,
Status
:
model
.
StatusActive
,
})
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
s
.
Require
()
.
NoError
(
err
,
"GetByKey"
)
s
.
Require
()
.
Equal
(
key
.
ID
,
got
.
ID
)
s
.
Require
()
.
NotNil
(
got
.
User
,
"expected User preload"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got
.
User
.
ID
)
s
.
Require
()
.
NotNil
(
got
.
Group
,
"expected Group preload"
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
Group
.
ID
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestGetByKey_NotFound
()
{
_
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
"non-existent-key"
)
s
.
Require
()
.
Error
(
err
,
"expected error for non-existent key"
)
}
// --- Update ---
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"update@test.com"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-update"
,
Name
:
"Original"
,
Status
:
model
.
StatusActive
,
})
key
.
Name
=
"Renamed"
key
.
Status
=
model
.
StatusDisabled
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
key
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetByID after update"
)
s
.
Require
()
.
Equal
(
"sk-update"
,
got
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got
.
Name
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
got
.
Status
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate_ClearGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"cleargroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-clear"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clear-group"
,
Name
:
"Group Key"
,
GroupID
:
&
group
.
ID
,
})
key
.
GroupID
=
nil
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
key
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Nil
(
got
.
GroupID
,
"expected GroupID to be cleared"
)
}
// --- Delete ---
func
(
s
*
ApiKeyRepoSuite
)
TestDelete
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"delete@test.com"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-delete"
,
Name
:
"Delete Me"
,
})
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
key
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"Delete"
)
_
,
err
=
s
.
repo
.
GetByID
(
s
.
ctx
,
key
.
ID
)
s
.
Require
()
.
Error
(
err
,
"expected error after delete"
)
}
// --- ListByUserID / CountByUserID ---
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"listbyuser@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-list-1"
,
Name
:
"Key 1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-list-2"
,
Name
:
"Key 2"
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"ListByUserID"
)
s
.
Require
()
.
Len
(
keys
,
2
)
s
.
Require
()
.
Equal
(
int64
(
2
),
page
.
Total
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID_Pagination
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"paging@test.com"
})
for
i
:=
0
;
i
<
5
;
i
++
{
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-page-"
+
string
(
rune
(
'a'
+
i
)),
Name
:
"Key"
,
})
}
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
2
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
keys
,
2
)
s
.
Require
()
.
Equal
(
int64
(
5
),
page
.
Total
)
s
.
Require
()
.
Equal
(
3
,
page
.
Pages
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestCountByUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"count@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-count-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-count-2"
,
Name
:
"K2"
})
count
,
err
:=
s
.
repo
.
CountByUserID
(
s
.
ctx
,
user
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByUserID"
)
s
.
Require
()
.
Equal
(
int64
(
2
),
count
)
}
// --- ListByGroupID / CountByGroupID ---
func
(
s
*
ApiKeyRepoSuite
)
TestListByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"listbygroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-list"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-3"
,
Name
:
"K3"
})
// no group
keys
,
page
,
err
:=
s
.
repo
.
ListByGroupID
(
s
.
ctx
,
group
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"ListByGroupID"
)
s
.
Require
()
.
Len
(
keys
,
2
)
s
.
Require
()
.
Equal
(
int64
(
2
),
page
.
Total
)
// User preloaded
s
.
Require
()
.
NotNil
(
keys
[
0
]
.
User
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestCountByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"countgroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-count"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-gc-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
count
,
err
:=
s
.
repo
.
CountByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByGroupID"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
count
)
}
// --- ExistsByKey ---
func
(
s
*
ApiKeyRepoSuite
)
TestExistsByKey
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"exists@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-exists"
,
Name
:
"K"
})
exists
,
err
:=
s
.
repo
.
ExistsByKey
(
s
.
ctx
,
"sk-exists"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByKey"
)
s
.
Require
()
.
True
(
exists
)
notExists
,
err
:=
s
.
repo
.
ExistsByKey
(
s
.
ctx
,
"sk-not-exists"
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
False
(
notExists
)
}
// --- SearchApiKeys ---
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"search@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-search-1"
,
Name
:
"Production Key"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-search-2"
,
Name
:
"Development Key"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
"prod"
,
10
)
s
.
Require
()
.
NoError
(
err
,
"SearchApiKeys"
)
s
.
Require
()
.
Len
(
found
,
1
)
s
.
Require
()
.
Contains
(
found
[
0
]
.
Name
,
"Production"
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoKeyword
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"searchnokw@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-2"
,
Name
:
"K2"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
""
,
10
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
found
,
2
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"searchnouid@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nu-1"
,
Name
:
"TestKey"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
0
,
"testkey"
,
10
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
found
,
1
)
}
// --- ClearGroupIDByGroupID ---
func
(
s
*
ApiKeyRepoSuite
)
TestClearGroupIDByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"cleargrp@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-clear-bulk"
})
k1
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-3"
,
Name
:
"K3"
})
// no group
affected
,
err
:=
s
.
repo
.
ClearGroupIDByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"ClearGroupIDByGroupID"
)
s
.
Require
()
.
Equal
(
int64
(
2
),
affected
)
got1
,
_
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
k1
.
ID
)
got2
,
_
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
k2
.
ID
)
s
.
Require
()
.
Nil
(
got1
.
GroupID
)
s
.
Require
()
.
Nil
(
got2
.
GroupID
)
count
,
_
:=
s
.
repo
.
CountByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
Zero
(
count
)
}
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func
(
s
*
ApiKeyRepoSuite
)
TestCRUD_Search_ClearGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"k@example.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-k"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-test-1"
,
Name
:
"My Key"
,
GroupID
:
&
group
.
ID
,
Status
:
model
.
StatusActive
,
})
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
s
.
Require
()
.
NoError
(
err
,
"GetByKey"
)
s
.
Require
()
.
Equal
(
key
.
ID
,
got
.
ID
)
s
.
Require
()
.
NotNil
(
got
.
User
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got
.
User
.
ID
)
s
.
Require
()
.
NotNil
(
got
.
Group
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
Group
.
ID
)
key
.
Name
=
"Renamed"
key
.
Status
=
model
.
StatusDisabled
key
.
GroupID
=
nil
s
.
Require
()
.
NoError
(
s
.
repo
.
Update
(
s
.
ctx
,
key
),
"Update"
)
got2
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
key
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetByID"
)
s
.
Require
()
.
Equal
(
"sk-test-1"
,
got2
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got2
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got2
.
Name
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
got2
.
Status
)
s
.
Require
()
.
Nil
(
got2
.
GroupID
)
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"ListByUserID"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
keys
,
1
)
exists
,
err
:=
s
.
repo
.
ExistsByKey
(
s
.
ctx
,
"sk-test-1"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByKey"
)
s
.
Require
()
.
True
(
exists
,
"expected key to exist"
)
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
"renam"
,
10
)
s
.
Require
()
.
NoError
(
err
,
"SearchApiKeys"
)
s
.
Require
()
.
Len
(
found
,
1
)
s
.
Require
()
.
Equal
(
key
.
ID
,
found
[
0
]
.
ID
)
// ClearGroupIDByGroupID
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-test-2"
,
Name
:
"Group Key"
,
GroupID
:
&
group
.
ID
,
})
countBefore
,
err
:=
s
.
repo
.
CountByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByGroupID"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
countBefore
,
"expected 1 key in group before clear"
)
affected
,
err
:=
s
.
repo
.
ClearGroupIDByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"ClearGroupIDByGroupID"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
affected
,
"expected 1 affected row"
)
got3
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
k2
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetByID"
)
s
.
Require
()
.
Nil
(
got3
.
GroupID
,
"expected GroupID cleared"
)
countAfter
,
err
:=
s
.
repo
.
CountByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByGroupID after clear"
)
s
.
Require
()
.
Equal
(
int64
(
0
),
countAfter
,
"expected 0 keys in group after clear"
)
}
backend/internal/repository/billing_cache_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
BillingCacheSuite
struct
{
IntegrationRedisSuite
}
func
(
s
*
BillingCacheSuite
)
TestUserBalance
()
{
tests
:=
[]
struct
{
name
string
fn
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
}{
{
name
:
"missing_key_returns_redis_nil"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
_
,
err
:=
cache
.
GetUserBalance
(
ctx
,
1
)
require
.
ErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected redis.Nil for missing balance key"
)
},
},
{
name
:
"deduct_on_nonexistent_is_noop"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
1
)
balanceKey
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
require
.
NoError
(
s
.
T
(),
cache
.
DeductUserBalance
(
ctx
,
userID
,
1
),
"DeductUserBalance should not error"
)
_
,
err
:=
rdb
.
Get
(
ctx
,
balanceKey
)
.
Result
()
require
.
ErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected missing key after deduct on non-existent"
)
},
},
{
name
:
"set_and_get_with_ttl"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
2
)
balanceKey
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
require
.
NoError
(
s
.
T
(),
cache
.
SetUserBalance
(
ctx
,
userID
,
10.5
),
"SetUserBalance"
)
got
,
err
:=
cache
.
GetUserBalance
(
ctx
,
userID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetUserBalance"
)
require
.
Equal
(
s
.
T
(),
10.5
,
got
,
"balance mismatch"
)
ttl
,
err
:=
rdb
.
TTL
(
ctx
,
balanceKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
billingCacheTTL
)
},
},
{
name
:
"deduct_reduces_balance"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
3
)
require
.
NoError
(
s
.
T
(),
cache
.
SetUserBalance
(
ctx
,
userID
,
10.5
),
"SetUserBalance"
)
require
.
NoError
(
s
.
T
(),
cache
.
DeductUserBalance
(
ctx
,
userID
,
2.25
),
"DeductUserBalance"
)
got
,
err
:=
cache
.
GetUserBalance
(
ctx
,
userID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetUserBalance after deduct"
)
require
.
Equal
(
s
.
T
(),
8.25
,
got
,
"deduct mismatch"
)
},
},
{
name
:
"invalidate_removes_key"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
100
)
balanceKey
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
require
.
NoError
(
s
.
T
(),
cache
.
SetUserBalance
(
ctx
,
userID
,
50.0
),
"SetUserBalance"
)
exists
,
err
:=
rdb
.
Exists
(
ctx
,
balanceKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"Exists"
)
require
.
Equal
(
s
.
T
(),
int64
(
1
),
exists
,
"expected balance key to exist"
)
require
.
NoError
(
s
.
T
(),
cache
.
InvalidateUserBalance
(
ctx
,
userID
),
"InvalidateUserBalance"
)
exists
,
err
=
rdb
.
Exists
(
ctx
,
balanceKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"Exists after invalidate"
)
require
.
Equal
(
s
.
T
(),
int64
(
0
),
exists
,
"expected balance key to be removed after invalidate"
)
_
,
err
=
cache
.
GetUserBalance
(
ctx
,
userID
)
require
.
ErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected redis.Nil after invalidate"
)
},
},
{
name
:
"deduct_refreshes_ttl"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
103
)
balanceKey
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
require
.
NoError
(
s
.
T
(),
cache
.
SetUserBalance
(
ctx
,
userID
,
100.0
),
"SetUserBalance"
)
ttl1
,
err
:=
rdb
.
TTL
(
ctx
,
balanceKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL before deduct"
)
s
.
AssertTTLWithin
(
ttl1
,
1
*
time
.
Second
,
billingCacheTTL
)
require
.
NoError
(
s
.
T
(),
cache
.
DeductUserBalance
(
ctx
,
userID
,
25.0
),
"DeductUserBalance"
)
balance
,
err
:=
cache
.
GetUserBalance
(
ctx
,
userID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetUserBalance"
)
require
.
Equal
(
s
.
T
(),
75.0
,
balance
,
"expected balance 75.0"
)
ttl2
,
err
:=
rdb
.
TTL
(
ctx
,
balanceKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL after deduct"
)
s
.
AssertTTLWithin
(
ttl2
,
1
*
time
.
Second
,
billingCacheTTL
)
},
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
rdb
:=
testRedis
(
s
.
T
())
cache
:=
NewBillingCache
(
rdb
)
ctx
:=
context
.
Background
()
tt
.
fn
(
ctx
,
rdb
,
cache
)
})
}
}
func
(
s
*
BillingCacheSuite
)
TestSubscriptionCache
()
{
tests
:=
[]
struct
{
name
string
fn
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
}{
{
name
:
"missing_key_returns_redis_nil"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
10
)
groupID
:=
int64
(
20
)
_
,
err
:=
cache
.
GetSubscriptionCache
(
ctx
,
userID
,
groupID
)
require
.
ErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected redis.Nil for missing subscription key"
)
},
},
{
name
:
"update_usage_on_nonexistent_is_noop"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
11
)
groupID
:=
int64
(
21
)
subKey
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
require
.
NoError
(
s
.
T
(),
cache
.
UpdateSubscriptionUsage
(
ctx
,
userID
,
groupID
,
1.0
),
"UpdateSubscriptionUsage should not error"
)
exists
,
err
:=
rdb
.
Exists
(
ctx
,
subKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"Exists"
)
require
.
Equal
(
s
.
T
(),
int64
(
0
),
exists
,
"expected missing subscription key after UpdateSubscriptionUsage on non-existent"
)
},
},
{
name
:
"set_and_get_with_ttl"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
12
)
groupID
:=
int64
(
22
)
subKey
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
data
:=
&
ports
.
SubscriptionCacheData
{
Status
:
"active"
,
ExpiresAt
:
time
.
Now
()
.
Add
(
1
*
time
.
Hour
),
DailyUsage
:
1.0
,
WeeklyUsage
:
2.0
,
MonthlyUsage
:
3.0
,
Version
:
7
,
}
require
.
NoError
(
s
.
T
(),
cache
.
SetSubscriptionCache
(
ctx
,
userID
,
groupID
,
data
),
"SetSubscriptionCache"
)
gotSub
,
err
:=
cache
.
GetSubscriptionCache
(
ctx
,
userID
,
groupID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetSubscriptionCache"
)
require
.
Equal
(
s
.
T
(),
"active"
,
gotSub
.
Status
)
require
.
Equal
(
s
.
T
(),
int64
(
7
),
gotSub
.
Version
)
require
.
Equal
(
s
.
T
(),
1.0
,
gotSub
.
DailyUsage
)
ttl
,
err
:=
rdb
.
TTL
(
ctx
,
subKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL subKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
billingCacheTTL
)
},
},
{
name
:
"update_usage_increments_all_fields"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
13
)
groupID
:=
int64
(
23
)
data
:=
&
ports
.
SubscriptionCacheData
{
Status
:
"active"
,
ExpiresAt
:
time
.
Now
()
.
Add
(
1
*
time
.
Hour
),
DailyUsage
:
1.0
,
WeeklyUsage
:
2.0
,
MonthlyUsage
:
3.0
,
Version
:
1
,
}
require
.
NoError
(
s
.
T
(),
cache
.
SetSubscriptionCache
(
ctx
,
userID
,
groupID
,
data
),
"SetSubscriptionCache"
)
require
.
NoError
(
s
.
T
(),
cache
.
UpdateSubscriptionUsage
(
ctx
,
userID
,
groupID
,
0.5
),
"UpdateSubscriptionUsage"
)
gotSub
,
err
:=
cache
.
GetSubscriptionCache
(
ctx
,
userID
,
groupID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetSubscriptionCache after update"
)
require
.
Equal
(
s
.
T
(),
1.5
,
gotSub
.
DailyUsage
)
require
.
Equal
(
s
.
T
(),
2.5
,
gotSub
.
WeeklyUsage
)
require
.
Equal
(
s
.
T
(),
3.5
,
gotSub
.
MonthlyUsage
)
},
},
{
name
:
"invalidate_removes_key"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
101
)
groupID
:=
int64
(
10
)
subKey
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
data
:=
&
ports
.
SubscriptionCacheData
{
Status
:
"active"
,
ExpiresAt
:
time
.
Now
()
.
Add
(
1
*
time
.
Hour
),
DailyUsage
:
1.0
,
WeeklyUsage
:
2.0
,
MonthlyUsage
:
3.0
,
Version
:
1
,
}
require
.
NoError
(
s
.
T
(),
cache
.
SetSubscriptionCache
(
ctx
,
userID
,
groupID
,
data
),
"SetSubscriptionCache"
)
exists
,
err
:=
rdb
.
Exists
(
ctx
,
subKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"Exists"
)
require
.
Equal
(
s
.
T
(),
int64
(
1
),
exists
,
"expected subscription key to exist"
)
require
.
NoError
(
s
.
T
(),
cache
.
InvalidateSubscriptionCache
(
ctx
,
userID
,
groupID
),
"InvalidateSubscriptionCache"
)
exists
,
err
=
rdb
.
Exists
(
ctx
,
subKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"Exists after invalidate"
)
require
.
Equal
(
s
.
T
(),
int64
(
0
),
exists
,
"expected subscription key to be removed after invalidate"
)
_
,
err
=
cache
.
GetSubscriptionCache
(
ctx
,
userID
,
groupID
)
require
.
ErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected redis.Nil after invalidate"
)
},
},
{
name
:
"missing_status_returns_parsing_error"
,
fn
:
func
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
cache
ports
.
BillingCache
)
{
userID
:=
int64
(
102
)
groupID
:=
int64
(
11
)
subKey
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
fields
:=
map
[
string
]
any
{
"expires_at"
:
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Unix
(),
"daily_usage"
:
1.0
,
"weekly_usage"
:
2.0
,
"monthly_usage"
:
3.0
,
"version"
:
1
,
}
require
.
NoError
(
s
.
T
(),
rdb
.
HSet
(
ctx
,
subKey
,
fields
)
.
Err
(),
"HSet"
)
_
,
err
:=
cache
.
GetSubscriptionCache
(
ctx
,
userID
,
groupID
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for missing status field"
)
require
.
NotErrorIs
(
s
.
T
(),
err
,
redis
.
Nil
,
"expected parsing error, not redis.Nil"
)
require
.
Equal
(
s
.
T
(),
"invalid cache: missing status"
,
err
.
Error
())
},
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
rdb
:=
testRedis
(
s
.
T
())
cache
:=
NewBillingCache
(
rdb
)
ctx
:=
context
.
Background
()
tt
.
fn
(
ctx
,
rdb
,
cache
)
})
}
}
func
TestBillingCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
BillingCacheSuite
))
}
backend/internal/repository/claude_oauth_service.go
View file @
25a304c2
...
...
@@ -16,20 +16,28 @@ import (
"github.com/imroc/req/v3"
)
type
claudeOAuthService
struct
{}
func
NewClaudeOAuthClient
()
service
.
ClaudeOAuthClient
{
return
&
claudeOAuthService
{}
return
&
claudeOAuthService
{
baseURL
:
"https://claude.ai"
,
tokenURL
:
oauth
.
TokenURL
,
clientFactory
:
createReqClient
,
}
}
type
claudeOAuthService
struct
{
baseURL
string
tokenURL
string
clientFactory
func
(
proxyURL
string
)
*
req
.
Client
}
func
(
s
*
claudeOAuthService
)
GetOrganizationUUID
(
ctx
context
.
Context
,
sessionKey
,
proxyURL
string
)
(
string
,
error
)
{
client
:=
createReqClient
(
proxyURL
)
client
:=
s
.
clientFactory
(
proxyURL
)
var
orgs
[]
struct
{
UUID
string
`json:"uuid"`
}
targetURL
:=
"https://claude.ai
/api/organizations"
targetURL
:=
s
.
baseURL
+
"
/api/organizations"
log
.
Printf
(
"[OAuth] Step 1: Getting organization UUID from %s"
,
targetURL
)
resp
,
err
:=
client
.
R
()
.
...
...
@@ -61,9 +69,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
func
(
s
*
claudeOAuthService
)
GetAuthorizationCode
(
ctx
context
.
Context
,
sessionKey
,
orgUUID
,
scope
,
codeChallenge
,
state
,
proxyURL
string
)
(
string
,
error
)
{
client
:=
createReqClient
(
proxyURL
)
client
:=
s
.
clientFactory
(
proxyURL
)
authURL
:=
fmt
.
Sprintf
(
"
https://claude.ai
/v1/oauth/%s/authorize"
,
orgUUID
)
authURL
:=
fmt
.
Sprintf
(
"
%s
/v1/oauth/%s/authorize"
,
s
.
baseURL
,
orgUUID
)
reqBody
:=
map
[
string
]
any
{
"response_type"
:
"code"
,
...
...
@@ -133,12 +141,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode
=
authCode
+
"#"
+
responseState
}
log
.
Printf
(
"[OAuth] Step 2 SUCCESS - Got authorization code: %s..."
,
authCode
[
:
20
]
)
log
.
Printf
(
"[OAuth] Step 2 SUCCESS - Got authorization code: %s..."
,
prefix
(
authCode
,
20
)
)
return
fullCode
,
nil
}
func
(
s
*
claudeOAuthService
)
ExchangeCodeForToken
(
ctx
context
.
Context
,
code
,
codeVerifier
,
state
,
proxyURL
string
)
(
*
oauth
.
TokenResponse
,
error
)
{
client
:=
createReqClient
(
proxyURL
)
client
:=
s
.
clientFactory
(
proxyURL
)
// Parse code which may contain state in format "authCode#state"
authCode
:=
code
...
...
@@ -161,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
reqBodyJSON
,
_
:=
json
.
Marshal
(
reqBody
)
log
.
Printf
(
"[OAuth] Step 3: Exchanging code for token at %s"
,
oauth
.
T
okenURL
)
log
.
Printf
(
"[OAuth] Step 3: Exchanging code for token at %s"
,
s
.
t
okenURL
)
log
.
Printf
(
"[OAuth] Step 3 Request Body: %s"
,
string
(
reqBodyJSON
))
var
tokenResp
oauth
.
TokenResponse
...
...
@@ -171,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
SetHeader
(
"Content-Type"
,
"application/json"
)
.
SetBody
(
reqBody
)
.
SetSuccessResult
(
&
tokenResp
)
.
Post
(
oauth
.
T
okenURL
)
Post
(
s
.
t
okenURL
)
if
err
!=
nil
{
log
.
Printf
(
"[OAuth] Step 3 FAILED - Request error: %v"
,
err
)
...
...
@@ -189,7 +197,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
func
(
s
*
claudeOAuthService
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
oauth
.
TokenResponse
,
error
)
{
client
:=
createReqClient
(
proxyURL
)
client
:=
s
.
clientFactory
(
proxyURL
)
formData
:=
url
.
Values
{}
formData
.
Set
(
"grant_type"
,
"refresh_token"
)
...
...
@@ -202,7 +210,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
SetContext
(
ctx
)
.
SetFormDataFromValues
(
formData
)
.
SetSuccessResult
(
&
tokenResp
)
.
Post
(
oauth
.
T
okenURL
)
Post
(
s
.
t
okenURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
...
...
@@ -226,3 +234,13 @@ func createReqClient(proxyURL string) *req.Client {
return
client
}
func
prefix
(
s
string
,
n
int
)
string
{
if
n
<=
0
{
return
""
}
if
len
(
s
)
<=
n
{
return
s
}
return
s
[
:
n
]
}
backend/internal/repository/claude_oauth_service_test.go
0 → 100644
View file @
25a304c2
package
repository
import
(
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
ClaudeOAuthServiceSuite
struct
{
suite
.
Suite
srv
*
httptest
.
Server
client
*
claudeOAuthService
}
func
(
s
*
ClaudeOAuthServiceSuite
)
TearDownTest
()
{
if
s
.
srv
!=
nil
{
s
.
srv
.
Close
()
s
.
srv
=
nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine.
type
requestCapture
struct
{
path
string
method
string
cookies
[]
*
http
.
Cookie
body
[]
byte
formValues
url
.
Values
bodyJSON
map
[
string
]
any
contentType
string
}
func
(
s
*
ClaudeOAuthServiceSuite
)
TestGetOrganizationUUID
()
{
tests
:=
[]
struct
{
name
string
handler
http
.
HandlerFunc
wantErr
bool
errContain
string
wantUUID
string
validate
func
(
captured
requestCapture
)
}{
{
name
:
"success"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`[{"uuid":"org-1"}]`
))
},
wantUUID
:
"org-1"
,
validate
:
func
(
captured
requestCapture
)
{
require
.
Equal
(
s
.
T
(),
"/api/organizations"
,
captured
.
path
,
"unexpected path"
)
require
.
Len
(
s
.
T
(),
captured
.
cookies
,
1
,
"expected 1 cookie"
)
require
.
Equal
(
s
.
T
(),
"sessionKey"
,
captured
.
cookies
[
0
]
.
Name
)
require
.
Equal
(
s
.
T
(),
"sess"
,
captured
.
cookies
[
0
]
.
Value
)
},
},
{
name
:
"non_200_returns_error"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusUnauthorized
)
_
,
_
=
w
.
Write
([]
byte
(
"unauthorized"
))
},
wantErr
:
true
,
errContain
:
"401"
,
},
{
name
:
"invalid_json_returns_error"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
"not-json"
))
},
wantErr
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
var
captured
requestCapture
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
captured
.
path
=
r
.
URL
.
Path
captured
.
cookies
=
r
.
Cookies
()
tt
.
handler
(
w
,
r
)
}))
defer
s
.
srv
.
Close
()
client
,
ok
:=
NewClaudeOAuthClient
()
.
(
*
claudeOAuthService
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
.
baseURL
=
s
.
srv
.
URL
got
,
err
:=
s
.
client
.
GetOrganizationUUID
(
context
.
Background
(),
"sess"
,
""
)
if
tt
.
wantErr
{
require
.
Error
(
s
.
T
(),
err
)
if
tt
.
errContain
!=
""
{
require
.
ErrorContains
(
s
.
T
(),
err
,
tt
.
errContain
)
}
return
}
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
tt
.
wantUUID
,
got
)
if
tt
.
validate
!=
nil
{
tt
.
validate
(
captured
)
}
})
}
}
func
(
s
*
ClaudeOAuthServiceSuite
)
TestGetAuthorizationCode
()
{
tests
:=
[]
struct
{
name
string
handler
http
.
HandlerFunc
wantErr
bool
wantCode
string
validate
func
(
captured
requestCapture
)
}{
{
name
:
"parses_redirect_uri"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
map
[
string
]
string
{
"redirect_uri"
:
oauth
.
RedirectURI
+
"?code=AUTH&state=STATE"
,
})
},
wantCode
:
"AUTH#STATE"
,
validate
:
func
(
captured
requestCapture
)
{
require
.
True
(
s
.
T
(),
strings
.
HasPrefix
(
captured
.
path
,
"/v1/oauth/"
)
&&
strings
.
HasSuffix
(
captured
.
path
,
"/authorize"
),
"unexpected path: %s"
,
captured
.
path
)
require
.
Equal
(
s
.
T
(),
http
.
MethodPost
,
captured
.
method
,
"expected POST"
)
require
.
Len
(
s
.
T
(),
captured
.
cookies
,
1
,
"expected 1 cookie"
)
require
.
Equal
(
s
.
T
(),
"sess"
,
captured
.
cookies
[
0
]
.
Value
)
require
.
Equal
(
s
.
T
(),
"org-1"
,
captured
.
bodyJSON
[
"organization_uuid"
])
require
.
Equal
(
s
.
T
(),
oauth
.
ClientID
,
captured
.
bodyJSON
[
"client_id"
])
require
.
Equal
(
s
.
T
(),
oauth
.
RedirectURI
,
captured
.
bodyJSON
[
"redirect_uri"
])
require
.
Equal
(
s
.
T
(),
"st"
,
captured
.
bodyJSON
[
"state"
])
},
},
{
name
:
"missing_code_returns_error"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
map
[
string
]
string
{
"redirect_uri"
:
oauth
.
RedirectURI
+
"?state=STATE"
,
// no code
})
},
wantErr
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
var
captured
requestCapture
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
captured
.
path
=
r
.
URL
.
Path
captured
.
method
=
r
.
Method
captured
.
cookies
=
r
.
Cookies
()
captured
.
body
,
_
=
io
.
ReadAll
(
r
.
Body
)
_
=
json
.
Unmarshal
(
captured
.
body
,
&
captured
.
bodyJSON
)
tt
.
handler
(
w
,
r
)
}))
defer
s
.
srv
.
Close
()
client
,
ok
:=
NewClaudeOAuthClient
()
.
(
*
claudeOAuthService
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
.
baseURL
=
s
.
srv
.
URL
code
,
err
:=
s
.
client
.
GetAuthorizationCode
(
context
.
Background
(),
"sess"
,
"org-1"
,
oauth
.
ScopeProfile
,
"cc"
,
"st"
,
""
)
if
tt
.
wantErr
{
require
.
Error
(
s
.
T
(),
err
)
return
}
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
tt
.
wantCode
,
code
)
if
tt
.
validate
!=
nil
{
tt
.
validate
(
captured
)
}
})
}
}
func
(
s
*
ClaudeOAuthServiceSuite
)
TestExchangeCodeForToken
()
{
tests
:=
[]
struct
{
name
string
handler
http
.
HandlerFunc
code
string
wantErr
bool
wantResp
*
oauth
.
TokenResponse
validate
func
(
captured
requestCapture
)
}{
{
name
:
"sends_state_when_embedded"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
oauth
.
TokenResponse
{
AccessToken
:
"at"
,
TokenType
:
"bearer"
,
ExpiresIn
:
3600
,
RefreshToken
:
"rt"
,
Scope
:
"s"
,
})
},
code
:
"AUTH#STATE2"
,
wantResp
:
&
oauth
.
TokenResponse
{
AccessToken
:
"at"
,
RefreshToken
:
"rt"
,
},
validate
:
func
(
captured
requestCapture
)
{
require
.
Equal
(
s
.
T
(),
http
.
MethodPost
,
captured
.
method
,
"expected POST"
)
require
.
True
(
s
.
T
(),
strings
.
HasPrefix
(
captured
.
contentType
,
"application/json"
),
"unexpected content-type"
)
require
.
Equal
(
s
.
T
(),
"AUTH"
,
captured
.
bodyJSON
[
"code"
])
require
.
Equal
(
s
.
T
(),
"STATE2"
,
captured
.
bodyJSON
[
"state"
])
require
.
Equal
(
s
.
T
(),
oauth
.
ClientID
,
captured
.
bodyJSON
[
"client_id"
])
require
.
Equal
(
s
.
T
(),
oauth
.
RedirectURI
,
captured
.
bodyJSON
[
"redirect_uri"
])
require
.
Equal
(
s
.
T
(),
"ver"
,
captured
.
bodyJSON
[
"code_verifier"
])
},
},
{
name
:
"non_200_returns_error"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
_
,
_
=
w
.
Write
([]
byte
(
"bad request"
))
},
code
:
"AUTH"
,
wantErr
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
var
captured
requestCapture
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
captured
.
method
=
r
.
Method
captured
.
contentType
=
r
.
Header
.
Get
(
"Content-Type"
)
captured
.
body
,
_
=
io
.
ReadAll
(
r
.
Body
)
_
=
json
.
Unmarshal
(
captured
.
body
,
&
captured
.
bodyJSON
)
tt
.
handler
(
w
,
r
)
}))
defer
s
.
srv
.
Close
()
client
,
ok
:=
NewClaudeOAuthClient
()
.
(
*
claudeOAuthService
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
.
tokenURL
=
s
.
srv
.
URL
resp
,
err
:=
s
.
client
.
ExchangeCodeForToken
(
context
.
Background
(),
tt
.
code
,
"ver"
,
""
,
""
)
if
tt
.
wantErr
{
require
.
Error
(
s
.
T
(),
err
)
return
}
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
tt
.
wantResp
.
AccessToken
,
resp
.
AccessToken
)
require
.
Equal
(
s
.
T
(),
tt
.
wantResp
.
RefreshToken
,
resp
.
RefreshToken
)
if
tt
.
validate
!=
nil
{
tt
.
validate
(
captured
)
}
})
}
}
func
(
s
*
ClaudeOAuthServiceSuite
)
TestRefreshToken
()
{
tests
:=
[]
struct
{
name
string
handler
http
.
HandlerFunc
wantErr
bool
wantResp
*
oauth
.
TokenResponse
validate
func
(
captured
requestCapture
)
}{
{
name
:
"sends_form"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
oauth
.
TokenResponse
{
AccessToken
:
"at2"
,
TokenType
:
"bearer"
,
ExpiresIn
:
3600
})
},
wantResp
:
&
oauth
.
TokenResponse
{
AccessToken
:
"at2"
},
validate
:
func
(
captured
requestCapture
)
{
require
.
Equal
(
s
.
T
(),
http
.
MethodPost
,
captured
.
method
,
"expected POST"
)
require
.
Equal
(
s
.
T
(),
"refresh_token"
,
captured
.
formValues
.
Get
(
"grant_type"
))
require
.
Equal
(
s
.
T
(),
"rt"
,
captured
.
formValues
.
Get
(
"refresh_token"
))
require
.
Equal
(
s
.
T
(),
oauth
.
ClientID
,
captured
.
formValues
.
Get
(
"client_id"
))
},
},
{
name
:
"non_200_returns_error"
,
handler
:
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusUnauthorized
)
_
,
_
=
w
.
Write
([]
byte
(
"unauthorized"
))
},
wantErr
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
var
captured
requestCapture
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
captured
.
method
=
r
.
Method
captured
.
body
,
_
=
io
.
ReadAll
(
r
.
Body
)
captured
.
formValues
,
_
=
url
.
ParseQuery
(
string
(
captured
.
body
))
tt
.
handler
(
w
,
r
)
}))
defer
s
.
srv
.
Close
()
client
,
ok
:=
NewClaudeOAuthClient
()
.
(
*
claudeOAuthService
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
.
tokenURL
=
s
.
srv
.
URL
resp
,
err
:=
s
.
client
.
RefreshToken
(
context
.
Background
(),
"rt"
,
""
)
if
tt
.
wantErr
{
require
.
Error
(
s
.
T
(),
err
)
return
}
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
tt
.
wantResp
.
AccessToken
,
resp
.
AccessToken
)
if
tt
.
validate
!=
nil
{
tt
.
validate
(
captured
)
}
})
}
}
func
TestClaudeOAuthServiceSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ClaudeOAuthServiceSuite
))
}
backend/internal/repository/claude_usage_service.go
View file @
25a304c2
...
...
@@ -12,10 +12,14 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
type
claudeUsageService
struct
{}
const
defaultClaudeUsageURL
=
"https://api.anthropic.com/api/oauth/usage"
type
claudeUsageService
struct
{
usageURL
string
}
func
NewClaudeUsageFetcher
()
service
.
ClaudeUsageFetcher
{
return
&
claudeUsageService
{}
return
&
claudeUsageService
{
usageURL
:
defaultClaudeUsageURL
}
}
func
(
s
*
claudeUsageService
)
FetchUsage
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
*
service
.
ClaudeUsageResponse
,
error
)
{
...
...
@@ -35,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
Timeout
:
30
*
time
.
Second
,
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://api.anthropic.com/api/oauth/
usage
"
,
nil
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
s
.
usage
URL
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create request failed: %w"
,
err
)
}
...
...
backend/internal/repository/claude_usage_service_test.go
0 → 100644
View file @
25a304c2
package
repository
import
(
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
ClaudeUsageServiceSuite
struct
{
suite
.
Suite
srv
*
httptest
.
Server
fetcher
*
claudeUsageService
}
func
(
s
*
ClaudeUsageServiceSuite
)
TearDownTest
()
{
if
s
.
srv
!=
nil
{
s
.
srv
.
Close
()
s
.
srv
=
nil
}
}
// usageRequestCapture holds captured request data for assertions in the main goroutine.
type
usageRequestCapture
struct
{
authorization
string
anthropicBeta
string
}
func
(
s
*
ClaudeUsageServiceSuite
)
TestFetchUsage_Success
()
{
var
captured
usageRequestCapture
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
captured
.
authorization
=
r
.
Header
.
Get
(
"Authorization"
)
captured
.
anthropicBeta
=
r
.
Header
.
Get
(
"anthropic-beta"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
}`
)
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
resp
,
err
:=
s
.
fetcher
.
FetchUsage
(
context
.
Background
(),
"at"
,
"://bad-proxy-url"
)
require
.
NoError
(
s
.
T
(),
err
,
"FetchUsage"
)
require
.
Equal
(
s
.
T
(),
12.5
,
resp
.
FiveHour
.
Utilization
,
"FiveHour utilization mismatch"
)
require
.
Equal
(
s
.
T
(),
34.0
,
resp
.
SevenDay
.
Utilization
,
"SevenDay utilization mismatch"
)
require
.
Equal
(
s
.
T
(),
56.0
,
resp
.
SevenDaySonnet
.
Utilization
,
"SevenDaySonnet utilization mismatch"
)
// Assertions on captured request data
require
.
Equal
(
s
.
T
(),
"Bearer at"
,
captured
.
authorization
,
"Authorization header mismatch"
)
require
.
Equal
(
s
.
T
(),
"oauth-2025-04-20"
,
captured
.
anthropicBeta
,
"anthropic-beta header mismatch"
)
}
func
(
s
*
ClaudeUsageServiceSuite
)
TestFetchUsage_NonOK
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusUnauthorized
)
_
,
_
=
io
.
WriteString
(
w
,
"nope"
)
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
_
,
err
:=
s
.
fetcher
.
FetchUsage
(
context
.
Background
(),
"at"
,
""
)
require
.
Error
(
s
.
T
(),
err
)
require
.
ErrorContains
(
s
.
T
(),
err
,
"status 401"
)
require
.
ErrorContains
(
s
.
T
(),
err
,
"nope"
)
}
func
(
s
*
ClaudeUsageServiceSuite
)
TestFetchUsage_BadJSON
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
"not-json"
)
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
_
,
err
:=
s
.
fetcher
.
FetchUsage
(
context
.
Background
(),
"at"
,
""
)
require
.
Error
(
s
.
T
(),
err
)
require
.
ErrorContains
(
s
.
T
(),
err
,
"decode response failed"
)
}
func
(
s
*
ClaudeUsageServiceSuite
)
TestFetchUsage_ContextCancel
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
// Never respond - simulate slow server
<-
r
.
Context
()
.
Done
()
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// Cancel immediately
_
,
err
:=
s
.
fetcher
.
FetchUsage
(
ctx
,
"at"
,
""
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for cancelled context"
)
}
func
TestClaudeUsageServiceSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ClaudeUsageServiceSuite
))
}
backend/internal/repository/concurrency_cache_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
ConcurrencyCacheSuite
struct
{
IntegrationRedisSuite
cache
ports
.
ConcurrencyCache
}
func
(
s
*
ConcurrencyCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewConcurrencyCache
(
s
.
rdb
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_AcquireAndRelease
()
{
accountID
:=
int64
(
10
)
reqID1
,
reqID2
,
reqID3
:=
"req1"
,
"req2"
,
"req3"
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
2
,
reqID1
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireAccountSlot 1"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
2
,
reqID2
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireAccountSlot 2"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
2
,
reqID3
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireAccountSlot 3"
)
require
.
False
(
s
.
T
(),
ok
,
"expected third acquire to fail"
)
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetAccountConcurrency"
)
require
.
Equal
(
s
.
T
(),
2
,
cur
,
"concurrency mismatch"
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
ReleaseAccountSlot
(
s
.
ctx
,
accountID
,
reqID1
),
"ReleaseAccountSlot"
)
cur
,
err
=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetAccountConcurrency after release"
)
require
.
Equal
(
s
.
T
(),
1
,
cur
,
"expected 1 after release"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_TTL
()
{
accountID
:=
int64
(
11
)
reqID
:=
"req_ttl_test"
slotKey
:=
fmt
.
Sprintf
(
"%s%d:%s"
,
accountSlotKeyPrefix
,
accountID
,
reqID
)
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireAccountSlot"
)
require
.
True
(
s
.
T
(),
ok
)
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
slotKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
slotTTL
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_DuplicateReqID
()
{
accountID
:=
int64
(
12
)
reqID
:=
"dup-req"
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
2
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Acquiring with same reqID should be idempotent
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
2
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
1
,
cur
,
"expected concurrency=1 (idempotent)"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_ReleaseIdempotent
()
{
accountID
:=
int64
(
13
)
reqID
:=
"release-test"
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
1
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
ReleaseAccountSlot
(
s
.
ctx
,
accountID
,
reqID
),
"ReleaseAccountSlot"
)
// Releasing again should not error
require
.
NoError
(
s
.
T
(),
s
.
cache
.
ReleaseAccountSlot
(
s
.
ctx
,
accountID
,
reqID
),
"ReleaseAccountSlot again"
)
// Releasing non-existent should not error
require
.
NoError
(
s
.
T
(),
s
.
cache
.
ReleaseAccountSlot
(
s
.
ctx
,
accountID
,
"non-existent"
),
"ReleaseAccountSlot non-existent"
)
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
0
,
cur
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_MaxZero
()
{
accountID
:=
int64
(
14
)
reqID
:=
"max-zero-test"
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
0
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
False
(
s
.
T
(),
ok
,
"expected acquire to fail with max=0"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestUserSlot_AcquireAndRelease
()
{
userID
:=
int64
(
42
)
reqID1
,
reqID2
:=
"req1"
,
"req2"
ok
,
err
:=
s
.
cache
.
AcquireUserSlot
(
s
.
ctx
,
userID
,
1
,
reqID1
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireUserSlot"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireUserSlot
(
s
.
ctx
,
userID
,
1
,
reqID2
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireUserSlot 2"
)
require
.
False
(
s
.
T
(),
ok
,
"expected second acquire to fail at max=1"
)
cur
,
err
:=
s
.
cache
.
GetUserConcurrency
(
s
.
ctx
,
userID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetUserConcurrency"
)
require
.
Equal
(
s
.
T
(),
1
,
cur
,
"expected concurrency=1"
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
ReleaseUserSlot
(
s
.
ctx
,
userID
,
reqID1
),
"ReleaseUserSlot"
)
// Releasing a non-existent slot should not error
require
.
NoError
(
s
.
T
(),
s
.
cache
.
ReleaseUserSlot
(
s
.
ctx
,
userID
,
"non-existent"
),
"ReleaseUserSlot non-existent"
)
cur
,
err
=
s
.
cache
.
GetUserConcurrency
(
s
.
ctx
,
userID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetUserConcurrency after release"
)
require
.
Equal
(
s
.
T
(),
0
,
cur
,
"expected concurrency=0 after release"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestUserSlot_TTL
()
{
userID
:=
int64
(
200
)
reqID
:=
"req_ttl_test"
slotKey
:=
fmt
.
Sprintf
(
"%s%d:%s"
,
userSlotKeyPrefix
,
userID
,
reqID
)
ok
,
err
:=
s
.
cache
.
AcquireUserSlot
(
s
.
ctx
,
userID
,
5
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireUserSlot"
)
require
.
True
(
s
.
T
(),
ok
)
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
slotKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
slotTTL
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestWaitQueue_IncrementAndDecrement
()
{
userID
:=
int64
(
20
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
waitQueueKeyPrefix
,
userID
)
ok
,
err
:=
s
.
cache
.
IncrementWaitCount
(
s
.
ctx
,
userID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementWaitCount 1"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementWaitCount
(
s
.
ctx
,
userID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementWaitCount 2"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementWaitCount
(
s
.
ctx
,
userID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementWaitCount 3"
)
require
.
False
(
s
.
T
(),
ok
,
"expected wait increment over max to fail"
)
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
waitKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL waitKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
slotTTL
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementWaitCount
(
s
.
ctx
,
userID
),
"DecrementWaitCount"
)
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
Equal
(
s
.
T
(),
1
,
val
,
"expected wait count 1"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestWaitQueue_DecrementNoNegative
()
{
userID
:=
int64
(
300
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
waitQueueKeyPrefix
,
userID
)
// Test decrement on non-existent key - should not error and should not create negative value
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementWaitCount
(
s
.
ctx
,
userID
),
"DecrementWaitCount on non-existent key"
)
// Verify no key was created or it's not negative
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative wait count after decrement on empty"
)
// Set count to 1, then decrement twice
ok
,
err
:=
s
.
cache
.
IncrementWaitCount
(
s
.
ctx
,
userID
,
5
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementWaitCount"
)
require
.
True
(
s
.
T
(),
ok
)
// Decrement once (1 -> 0)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementWaitCount
(
s
.
ctx
,
userID
),
"DecrementWaitCount"
)
// Decrement again on 0 - should not go negative
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementWaitCount
(
s
.
ctx
,
userID
),
"DecrementWaitCount on zero"
)
// Verify count is 0, not negative
val
,
err
=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey after double decrement"
)
}
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative wait count"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountConcurrency_Missing
()
{
// When no slots exist, GetAccountConcurrency should return 0
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
999
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
0
,
cur
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetUserConcurrency_Missing
()
{
// When no slots exist, GetUserConcurrency should return 0
cur
,
err
:=
s
.
cache
.
GetUserConcurrency
(
s
.
ctx
,
999
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
0
,
cur
)
}
func
TestConcurrencyCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ConcurrencyCacheSuite
))
}
backend/internal/repository/email_cache_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
EmailCacheSuite
struct
{
IntegrationRedisSuite
cache
ports
.
EmailCache
}
func
(
s
*
EmailCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewEmailCache
(
s
.
rdb
)
}
func
(
s
*
EmailCacheSuite
)
TestGetVerificationCode_Missing
()
{
_
,
err
:=
s
.
cache
.
GetVerificationCode
(
s
.
ctx
,
"nonexistent@example.com"
)
require
.
True
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected redis.Nil for missing verification code"
)
}
func
(
s
*
EmailCacheSuite
)
TestSetAndGetVerificationCode
()
{
email
:=
"a@example.com"
emailTTL
:=
2
*
time
.
Minute
data
:=
&
ports
.
VerificationCodeData
{
Code
:
"123456"
,
Attempts
:
1
,
CreatedAt
:
time
.
Now
()}
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetVerificationCode
(
s
.
ctx
,
email
,
data
,
emailTTL
),
"SetVerificationCode"
)
got
,
err
:=
s
.
cache
.
GetVerificationCode
(
s
.
ctx
,
email
)
require
.
NoError
(
s
.
T
(),
err
,
"GetVerificationCode"
)
require
.
Equal
(
s
.
T
(),
"123456"
,
got
.
Code
)
require
.
Equal
(
s
.
T
(),
1
,
got
.
Attempts
)
}
func
(
s
*
EmailCacheSuite
)
TestVerificationCode_TTL
()
{
email
:=
"ttl@example.com"
emailTTL
:=
2
*
time
.
Minute
data
:=
&
ports
.
VerificationCodeData
{
Code
:
"654321"
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
()}
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetVerificationCode
(
s
.
ctx
,
email
,
data
,
emailTTL
),
"SetVerificationCode"
)
emailKey
:=
verifyCodeKeyPrefix
+
email
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
emailKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL emailKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
emailTTL
)
}
func
(
s
*
EmailCacheSuite
)
TestDeleteVerificationCode
()
{
email
:=
"delete@example.com"
data
:=
&
ports
.
VerificationCodeData
{
Code
:
"999999"
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
()}
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetVerificationCode
(
s
.
ctx
,
email
,
data
,
2
*
time
.
Minute
),
"SetVerificationCode"
)
// Verify it exists
_
,
err
:=
s
.
cache
.
GetVerificationCode
(
s
.
ctx
,
email
)
require
.
NoError
(
s
.
T
(),
err
,
"GetVerificationCode before delete"
)
// Delete
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DeleteVerificationCode
(
s
.
ctx
,
email
),
"DeleteVerificationCode"
)
// Verify it's gone
_
,
err
=
s
.
cache
.
GetVerificationCode
(
s
.
ctx
,
email
)
require
.
True
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected redis.Nil after delete"
)
}
func
(
s
*
EmailCacheSuite
)
TestDeleteVerificationCode_NonExistent
()
{
// Deleting a non-existent key should not error
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DeleteVerificationCode
(
s
.
ctx
,
"nonexistent@example.com"
),
"DeleteVerificationCode non-existent"
)
}
func
(
s
*
EmailCacheSuite
)
TestGetVerificationCode_JSONCorruption
()
{
emailKey
:=
verifyCodeKeyPrefix
+
"corrupted@example.com"
require
.
NoError
(
s
.
T
(),
s
.
rdb
.
Set
(
s
.
ctx
,
emailKey
,
"not-json"
,
1
*
time
.
Minute
)
.
Err
(),
"Set invalid JSON"
)
_
,
err
:=
s
.
cache
.
GetVerificationCode
(
s
.
ctx
,
"corrupted@example.com"
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for corrupted JSON"
)
require
.
False
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected decoding error, not redis.Nil"
)
}
func
TestEmailCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
EmailCacheSuite
))
}
backend/internal/repository/fixtures_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func
mustCreateUser
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
u
*
model
.
User
)
*
model
.
User
{
t
.
Helper
()
if
u
.
PasswordHash
==
""
{
u
.
PasswordHash
=
"test-password-hash"
}
if
u
.
Role
==
""
{
u
.
Role
=
model
.
RoleUser
}
if
u
.
Status
==
""
{
u
.
Status
=
model
.
StatusActive
}
if
u
.
CreatedAt
.
IsZero
()
{
u
.
CreatedAt
=
time
.
Now
()
}
if
u
.
UpdatedAt
.
IsZero
()
{
u
.
UpdatedAt
=
u
.
CreatedAt
}
require
.
NoError
(
t
,
db
.
Create
(
u
)
.
Error
,
"create user"
)
return
u
}
func
mustCreateGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
g
*
model
.
Group
)
*
model
.
Group
{
t
.
Helper
()
if
g
.
Platform
==
""
{
g
.
Platform
=
model
.
PlatformAnthropic
}
if
g
.
Status
==
""
{
g
.
Status
=
model
.
StatusActive
}
if
g
.
SubscriptionType
==
""
{
g
.
SubscriptionType
=
model
.
SubscriptionTypeStandard
}
if
g
.
CreatedAt
.
IsZero
()
{
g
.
CreatedAt
=
time
.
Now
()
}
if
g
.
UpdatedAt
.
IsZero
()
{
g
.
UpdatedAt
=
g
.
CreatedAt
}
require
.
NoError
(
t
,
db
.
Create
(
g
)
.
Error
,
"create group"
)
return
g
}
func
mustCreateProxy
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
p
*
model
.
Proxy
)
*
model
.
Proxy
{
t
.
Helper
()
if
p
.
Protocol
==
""
{
p
.
Protocol
=
"http"
}
if
p
.
Host
==
""
{
p
.
Host
=
"127.0.0.1"
}
if
p
.
Port
==
0
{
p
.
Port
=
8080
}
if
p
.
Status
==
""
{
p
.
Status
=
model
.
StatusActive
}
if
p
.
CreatedAt
.
IsZero
()
{
p
.
CreatedAt
=
time
.
Now
()
}
if
p
.
UpdatedAt
.
IsZero
()
{
p
.
UpdatedAt
=
p
.
CreatedAt
}
require
.
NoError
(
t
,
db
.
Create
(
p
)
.
Error
,
"create proxy"
)
return
p
}
func
mustCreateAccount
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
a
*
model
.
Account
)
*
model
.
Account
{
t
.
Helper
()
if
a
.
Platform
==
""
{
a
.
Platform
=
model
.
PlatformAnthropic
}
if
a
.
Type
==
""
{
a
.
Type
=
model
.
AccountTypeOAuth
}
if
a
.
Status
==
""
{
a
.
Status
=
model
.
StatusActive
}
if
!
a
.
Schedulable
{
a
.
Schedulable
=
true
}
if
a
.
Credentials
==
nil
{
a
.
Credentials
=
model
.
JSONB
{}
}
if
a
.
Extra
==
nil
{
a
.
Extra
=
model
.
JSONB
{}
}
if
a
.
CreatedAt
.
IsZero
()
{
a
.
CreatedAt
=
time
.
Now
()
}
if
a
.
UpdatedAt
.
IsZero
()
{
a
.
UpdatedAt
=
a
.
CreatedAt
}
require
.
NoError
(
t
,
db
.
Create
(
a
)
.
Error
,
"create account"
)
return
a
}
func
mustCreateApiKey
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
k
*
model
.
ApiKey
)
*
model
.
ApiKey
{
t
.
Helper
()
if
k
.
Status
==
""
{
k
.
Status
=
model
.
StatusActive
}
if
k
.
CreatedAt
.
IsZero
()
{
k
.
CreatedAt
=
time
.
Now
()
}
if
k
.
UpdatedAt
.
IsZero
()
{
k
.
UpdatedAt
=
k
.
CreatedAt
}
require
.
NoError
(
t
,
db
.
Create
(
k
)
.
Error
,
"create api key"
)
return
k
}
func
mustCreateRedeemCode
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
c
*
model
.
RedeemCode
)
*
model
.
RedeemCode
{
t
.
Helper
()
if
c
.
Status
==
""
{
c
.
Status
=
model
.
StatusUnused
}
if
c
.
Type
==
""
{
c
.
Type
=
model
.
RedeemTypeBalance
}
if
c
.
CreatedAt
.
IsZero
()
{
c
.
CreatedAt
=
time
.
Now
()
}
require
.
NoError
(
t
,
db
.
Create
(
c
)
.
Error
,
"create redeem code"
)
return
c
}
func
mustCreateSubscription
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
s
*
model
.
UserSubscription
)
*
model
.
UserSubscription
{
t
.
Helper
()
if
s
.
Status
==
""
{
s
.
Status
=
model
.
SubscriptionStatusActive
}
now
:=
time
.
Now
()
if
s
.
StartsAt
.
IsZero
()
{
s
.
StartsAt
=
now
.
Add
(
-
1
*
time
.
Hour
)
}
if
s
.
ExpiresAt
.
IsZero
()
{
s
.
ExpiresAt
=
now
.
Add
(
24
*
time
.
Hour
)
}
if
s
.
AssignedAt
.
IsZero
()
{
s
.
AssignedAt
=
now
}
if
s
.
CreatedAt
.
IsZero
()
{
s
.
CreatedAt
=
now
}
if
s
.
UpdatedAt
.
IsZero
()
{
s
.
UpdatedAt
=
now
}
require
.
NoError
(
t
,
db
.
Create
(
s
)
.
Error
,
"create user subscription"
)
return
s
}
func
mustBindAccountToGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
accountID
,
groupID
int64
,
priority
int
)
{
t
.
Helper
()
require
.
NoError
(
t
,
db
.
Create
(
&
model
.
AccountGroup
{
AccountID
:
accountID
,
GroupID
:
groupID
,
Priority
:
priority
,
})
.
Error
,
"create account_group"
)
}
backend/internal/repository/gateway_cache_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
GatewayCacheSuite
struct
{
IntegrationRedisSuite
cache
ports
.
GatewayCache
}
func
(
s
*
GatewayCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewGatewayCache
(
s
.
rdb
)
}
func
(
s
*
GatewayCacheSuite
)
TestGetSessionAccountID_Missing
()
{
_
,
err
:=
s
.
cache
.
GetSessionAccountID
(
s
.
ctx
,
"nonexistent"
)
require
.
True
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected redis.Nil for missing session"
)
}
func
(
s
*
GatewayCacheSuite
)
TestSetAndGetSessionAccountID
()
{
sessionID
:=
"s1"
accountID
:=
int64
(
99
)
sessionTTL
:=
1
*
time
.
Minute
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetSessionAccountID
(
s
.
ctx
,
sessionID
,
accountID
,
sessionTTL
),
"SetSessionAccountID"
)
sid
,
err
:=
s
.
cache
.
GetSessionAccountID
(
s
.
ctx
,
sessionID
)
require
.
NoError
(
s
.
T
(),
err
,
"GetSessionAccountID"
)
require
.
Equal
(
s
.
T
(),
accountID
,
sid
,
"session id mismatch"
)
}
func
(
s
*
GatewayCacheSuite
)
TestSessionAccountID_TTL
()
{
sessionID
:=
"s2"
accountID
:=
int64
(
100
)
sessionTTL
:=
1
*
time
.
Minute
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetSessionAccountID
(
s
.
ctx
,
sessionID
,
accountID
,
sessionTTL
),
"SetSessionAccountID"
)
sessionKey
:=
stickySessionPrefix
+
sessionID
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
sessionKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL sessionKey after Set"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
sessionTTL
)
}
func
(
s
*
GatewayCacheSuite
)
TestRefreshSessionTTL
()
{
sessionID
:=
"s3"
accountID
:=
int64
(
101
)
initialTTL
:=
1
*
time
.
Minute
refreshTTL
:=
3
*
time
.
Minute
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetSessionAccountID
(
s
.
ctx
,
sessionID
,
accountID
,
initialTTL
),
"SetSessionAccountID"
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
RefreshSessionTTL
(
s
.
ctx
,
sessionID
,
refreshTTL
),
"RefreshSessionTTL"
)
sessionKey
:=
stickySessionPrefix
+
sessionID
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
sessionKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL after Refresh"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
refreshTTL
)
}
func
(
s
*
GatewayCacheSuite
)
TestRefreshSessionTTL_MissingKey
()
{
// RefreshSessionTTL on a missing key should not error (no-op)
err
:=
s
.
cache
.
RefreshSessionTTL
(
s
.
ctx
,
"missing-session"
,
1
*
time
.
Minute
)
require
.
NoError
(
s
.
T
(),
err
,
"RefreshSessionTTL on missing key should not error"
)
}
func
(
s
*
GatewayCacheSuite
)
TestGetSessionAccountID_CorruptedValue
()
{
sessionID
:=
"corrupted"
sessionKey
:=
stickySessionPrefix
+
sessionID
// Set a non-integer value
require
.
NoError
(
s
.
T
(),
s
.
rdb
.
Set
(
s
.
ctx
,
sessionKey
,
"not-a-number"
,
1
*
time
.
Minute
)
.
Err
(),
"Set invalid value"
)
_
,
err
:=
s
.
cache
.
GetSessionAccountID
(
s
.
ctx
,
sessionID
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for corrupted value"
)
require
.
False
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected parsing error, not redis.Nil"
)
}
func
TestGatewayCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GatewayCacheSuite
))
}
backend/internal/repository/github_release_service_test.go
0 → 100644
View file @
25a304c2
package
repository
import
(
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
GitHubReleaseServiceSuite
struct
{
suite
.
Suite
srv
*
httptest
.
Server
client
*
githubReleaseClient
tempDir
string
}
// testTransport redirects requests to the test server
type
testTransport
struct
{
testServerURL
string
}
func
(
t
*
testTransport
)
RoundTrip
(
req
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
// Rewrite the URL to point to our test server
testURL
:=
t
.
testServerURL
+
req
.
URL
.
Path
newReq
,
err
:=
http
.
NewRequestWithContext
(
req
.
Context
(),
req
.
Method
,
testURL
,
req
.
Body
)
if
err
!=
nil
{
return
nil
,
err
}
newReq
.
Header
=
req
.
Header
return
http
.
DefaultTransport
.
RoundTrip
(
newReq
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
SetupTest
()
{
s
.
tempDir
=
s
.
T
()
.
TempDir
()
}
func
(
s
*
GitHubReleaseServiceSuite
)
TearDownTest
()
{
if
s
.
srv
!=
nil
{
s
.
srv
.
Close
()
s
.
srv
=
nil
}
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_EnforcesMaxSize_ContentLength
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Length"
,
"100"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
(
bytes
.
Repeat
([]
byte
(
"a"
),
100
))
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"file1.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
10
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for oversized download with Content-Length"
)
_
,
statErr
:=
os
.
Stat
(
dest
)
require
.
Error
(
s
.
T
(),
statErr
,
"expected file to not exist for rejected download"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_EnforcesMaxSize_Chunked
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w
.
WriteHeader
(
http
.
StatusOK
)
if
fl
,
ok
:=
w
.
(
http
.
Flusher
);
ok
{
fl
.
Flush
()
}
for
i
:=
0
;
i
<
10
;
i
++
{
_
,
_
=
w
.
Write
(
bytes
.
Repeat
([]
byte
(
"b"
),
10
))
if
fl
,
ok
:=
w
.
(
http
.
Flusher
);
ok
{
fl
.
Flush
()
}
}
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"file2.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
10
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for oversized chunked download"
)
_
,
statErr
:=
os
.
Stat
(
dest
)
require
.
Error
(
s
.
T
(),
statErr
,
"expected file to be cleaned up for oversized chunked download"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_Success
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
if
fl
,
ok
:=
w
.
(
http
.
Flusher
);
ok
{
fl
.
Flush
()
}
for
i
:=
0
;
i
<
10
;
i
++
{
_
,
_
=
w
.
Write
(
bytes
.
Repeat
([]
byte
(
"b"
),
10
))
if
fl
,
ok
:=
w
.
(
http
.
Flusher
);
ok
{
fl
.
Flush
()
}
}
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"file3.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
200
)
require
.
NoError
(
s
.
T
(),
err
,
"expected success"
)
b
,
err
:=
os
.
ReadFile
(
dest
)
require
.
NoError
(
s
.
T
(),
err
,
"read"
)
require
.
True
(
s
.
T
(),
strings
.
HasPrefix
(
string
(
b
),
"b"
),
"downloaded content should start with 'b'"
)
require
.
Len
(
s
.
T
(),
b
,
100
,
"downloaded content length mismatch"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_404
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusNotFound
)
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"notfound.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
100
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for 404"
)
_
,
statErr
:=
os
.
Stat
(
dest
)
require
.
Error
(
s
.
T
(),
statErr
,
"expected file to not exist for 404"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchChecksumFile_Success
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"sum"
))
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
body
,
err
:=
s
.
client
.
FetchChecksumFile
(
context
.
Background
(),
s
.
srv
.
URL
)
require
.
NoError
(
s
.
T
(),
err
,
"FetchChecksumFile"
)
require
.
Equal
(
s
.
T
(),
"sum"
,
string
(
body
),
"checksum body mismatch"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchChecksumFile_Non200
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusInternalServerError
)
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
_
,
err
:=
s
.
client
.
FetchChecksumFile
(
context
.
Background
(),
s
.
srv
.
URL
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for non-200"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_ContextCancel
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
<-
r
.
Context
()
.
Done
()
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"cancelled.bin"
)
err
:=
s
.
client
.
DownloadFile
(
ctx
,
s
.
srv
.
URL
,
dest
,
100
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for cancelled context"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_InvalidURL
()
{
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"invalid.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
"://invalid-url"
,
dest
,
100
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for invalid URL"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_InvalidDestPath
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"content"
))
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
// Use a path that cannot be created (directory doesn't exist)
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"nonexistent"
,
"subdir"
,
"file.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
100
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for invalid destination path"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchChecksumFile_InvalidURL
()
{
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
_
,
err
:=
s
.
client
.
FetchChecksumFile
(
context
.
Background
(),
"://invalid-url"
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for invalid URL"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchLatestRelease_Success
()
{
releaseJSON
:=
`{
"tag_name": "v1.0.0",
"name": "Release 1.0.0",
"body": "Release notes",
"html_url": "https://github.com/test/repo/releases/v1.0.0",
"assets": [
{
"name": "app-linux-amd64.tar.gz",
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
}
]
}`
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
s
.
T
(),
"/repos/test/repo/releases/latest"
,
r
.
URL
.
Path
)
require
.
Equal
(
s
.
T
(),
"application/vnd.github.v3+json"
,
r
.
Header
.
Get
(
"Accept"
))
require
.
Equal
(
s
.
T
(),
"Sub2API-Updater"
,
r
.
Header
.
Get
(
"User-Agent"
))
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
releaseJSON
))
}))
// Use custom transport to redirect requests to test server
s
.
client
=
&
githubReleaseClient
{
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
}
release
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
"v1.0.0"
,
release
.
TagName
)
require
.
Equal
(
s
.
T
(),
"Release 1.0.0"
,
release
.
Name
)
require
.
Len
(
s
.
T
(),
release
.
Assets
,
1
)
require
.
Equal
(
s
.
T
(),
"app-linux-amd64.tar.gz"
,
release
.
Assets
[
0
]
.
Name
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchLatestRelease_Non200
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusNotFound
)
}))
s
.
client
=
&
githubReleaseClient
{
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
}
_
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
require
.
Error
(
s
.
T
(),
err
)
require
.
Contains
(
s
.
T
(),
err
.
Error
(),
"404"
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchLatestRelease_InvalidJSON
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"not valid json"
))
}))
s
.
client
=
&
githubReleaseClient
{
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
}
_
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
require
.
Error
(
s
.
T
(),
err
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchLatestRelease_ContextCancel
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
<-
r
.
Context
()
.
Done
()
}))
s
.
client
=
&
githubReleaseClient
{
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
}
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
_
,
err
:=
s
.
client
.
FetchLatestRelease
(
ctx
,
"test/repo"
)
require
.
Error
(
s
.
T
(),
err
)
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchChecksumFile_ContextCancel
()
{
s
.
srv
=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
<-
r
.
Context
()
.
Done
()
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
_
,
err
:=
s
.
client
.
FetchChecksumFile
(
ctx
,
s
.
srv
.
URL
)
require
.
Error
(
s
.
T
(),
err
)
}
func
TestGitHubReleaseServiceSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GitHubReleaseServiceSuite
))
}
backend/internal/repository/group_repo_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type
GroupRepoSuite
struct
{
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
GroupRepository
}
func
(
s
*
GroupRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewGroupRepository
(
s
.
db
)
}
func
TestGroupRepoSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GroupRepoSuite
))
}
// --- Create / GetByID / Update / Delete ---
func
(
s
*
GroupRepoSuite
)
TestCreate
()
{
group
:=
&
model
.
Group
{
Name
:
"test-create"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
}
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
group
)
s
.
Require
()
.
NoError
(
err
,
"Create"
)
s
.
Require
()
.
NotZero
(
group
.
ID
,
"expected ID to be set"
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetByID"
)
s
.
Require
()
.
Equal
(
"test-create"
,
got
.
Name
)
}
func
(
s
*
GroupRepoSuite
)
TestGetByID_NotFound
()
{
_
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
999999
)
s
.
Require
()
.
Error
(
err
,
"expected error for non-existent ID"
)
}
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"original"
})
group
.
Name
=
"updated"
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
group
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetByID after update"
)
s
.
Require
()
.
Equal
(
"updated"
,
got
.
Name
)
}
func
(
s
*
GroupRepoSuite
)
TestDelete
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"to-delete"
})
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"Delete"
)
_
,
err
=
s
.
repo
.
GetByID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
Error
(
err
,
"expected error after delete"
)
}
// --- List / ListWithFilters ---
func
(
s
*
GroupRepoSuite
)
TestList
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
})
groups
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
Len
(
groups
,
2
)
s
.
Require
()
.
Equal
(
int64
(
2
),
page
.
Total
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Platform
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Platform
:
model
.
PlatformOpenAI
})
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
model
.
PlatformOpenAI
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
model
.
PlatformOpenAI
,
groups
[
0
]
.
Platform
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Status
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Status
:
model
.
StatusDisabled
})
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
model
.
StatusDisabled
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
groups
[
0
]
.
Status
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_IsExclusive
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
IsExclusive
:
false
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
IsExclusive
:
true
})
isExclusive
:=
true
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
True
(
groups
[
0
]
.
IsExclusive
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_AccountCount
()
{
g1
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
})
g2
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
IsExclusive
:
true
,
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
Account
{
Name
:
"acc1"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g1
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g2
.
ID
,
1
)
isExclusive
:=
true
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
model
.
PlatformAnthropic
,
model
.
StatusActive
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
g2
.
ID
,
groups
[
0
]
.
ID
,
"ListWithFilters returned wrong group"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
groups
[
0
]
.
AccountCount
,
"AccountCount mismatch"
)
}
// --- ListActive / ListActiveByPlatform ---
func
(
s
*
GroupRepoSuite
)
TestListActive
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"active1"
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"inactive1"
,
Status
:
model
.
StatusDisabled
})
groups
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
"active1"
,
groups
[
0
]
.
Name
)
}
func
(
s
*
GroupRepoSuite
)
TestListActiveByPlatform
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Platform
:
model
.
PlatformOpenAI
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g3"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusDisabled
})
groups
,
err
:=
s
.
repo
.
ListActiveByPlatform
(
s
.
ctx
,
model
.
PlatformAnthropic
)
s
.
Require
()
.
NoError
(
err
,
"ListActiveByPlatform"
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
"g1"
,
groups
[
0
]
.
Name
)
}
// --- ExistsByName ---
func
(
s
*
GroupRepoSuite
)
TestExistsByName
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"existing-group"
})
exists
,
err
:=
s
.
repo
.
ExistsByName
(
s
.
ctx
,
"existing-group"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByName"
)
s
.
Require
()
.
True
(
exists
)
notExists
,
err
:=
s
.
repo
.
ExistsByName
(
s
.
ctx
,
"non-existing"
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
False
(
notExists
)
}
// --- GetAccountCount ---
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-count"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
Account
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
Account
{
Name
:
"a2"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
group
.
ID
,
2
)
count
,
err
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetAccountCount"
)
s
.
Require
()
.
Equal
(
int64
(
2
),
count
)
}
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount_Empty
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-empty"
})
count
,
err
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Zero
(
count
)
}
// --- DeleteAccountGroupsByGroupID ---
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID
()
{
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-del"
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
Account
{
Name
:
"acc-del"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g
.
ID
,
1
)
affected
,
err
:=
s
.
repo
.
DeleteAccountGroupsByGroupID
(
s
.
ctx
,
g
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"DeleteAccountGroupsByGroupID"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
affected
,
"expected 1 affected row"
)
count
,
err
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
g
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetAccountCount"
)
s
.
Require
()
.
Equal
(
int64
(
0
),
count
,
"expected 0 account groups"
)
}
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID_MultipleAccounts
()
{
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-multi"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
Account
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
Account
{
Name
:
"a2"
})
a3
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
Account
{
Name
:
"a3"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
g
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
g
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a3
.
ID
,
g
.
ID
,
3
)
affected
,
err
:=
s
.
repo
.
DeleteAccountGroupsByGroupID
(
s
.
ctx
,
g
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
int64
(
3
),
affected
)
count
,
_
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
g
.
ID
)
s
.
Require
()
.
Zero
(
count
)
}
// --- DB ---
func
(
s
*
GroupRepoSuite
)
TestDB
()
{
db
:=
s
.
repo
.
DB
()
s
.
Require
()
.
NotNil
(
db
,
"DB should return non-nil"
)
s
.
Require
()
.
Equal
(
s
.
db
,
db
,
"DB should return the underlying gorm.DB"
)
}
backend/internal/repository/http_upstream_test.go
0 → 100644
View file @
25a304c2
package
repository
import
(
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
HTTPUpstreamSuite
struct
{
suite
.
Suite
cfg
*
config
.
Config
}
func
(
s
*
HTTPUpstreamSuite
)
SetupTest
()
{
s
.
cfg
=
&
config
.
Config
{}
}
func
(
s
*
HTTPUpstreamSuite
)
TestDefaultResponseHeaderTimeout
()
{
up
:=
NewHTTPUpstream
(
s
.
cfg
)
svc
,
ok
:=
up
.
(
*
httpUpstreamService
)
require
.
True
(
s
.
T
(),
ok
,
"expected *httpUpstreamService"
)
transport
,
ok
:=
svc
.
defaultClient
.
Transport
.
(
*
http
.
Transport
)
require
.
True
(
s
.
T
(),
ok
,
"expected *http.Transport"
)
require
.
Equal
(
s
.
T
(),
300
*
time
.
Second
,
transport
.
ResponseHeaderTimeout
,
"ResponseHeaderTimeout mismatch"
)
}
func
(
s
*
HTTPUpstreamSuite
)
TestCustomResponseHeaderTimeout
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ResponseHeaderTimeout
:
7
}
up
:=
NewHTTPUpstream
(
s
.
cfg
)
svc
,
ok
:=
up
.
(
*
httpUpstreamService
)
require
.
True
(
s
.
T
(),
ok
,
"expected *httpUpstreamService"
)
transport
,
ok
:=
svc
.
defaultClient
.
Transport
.
(
*
http
.
Transport
)
require
.
True
(
s
.
T
(),
ok
,
"expected *http.Transport"
)
require
.
Equal
(
s
.
T
(),
7
*
time
.
Second
,
transport
.
ResponseHeaderTimeout
,
"ResponseHeaderTimeout mismatch"
)
}
func
(
s
*
HTTPUpstreamSuite
)
TestCreateProxyClient_InvalidURLFallsBackToDefault
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ResponseHeaderTimeout
:
5
}
up
:=
NewHTTPUpstream
(
s
.
cfg
)
svc
,
ok
:=
up
.
(
*
httpUpstreamService
)
require
.
True
(
s
.
T
(),
ok
,
"expected *httpUpstreamService"
)
got
:=
svc
.
createProxyClient
(
"://bad-proxy-url"
)
require
.
Equal
(
s
.
T
(),
svc
.
defaultClient
,
got
,
"expected defaultClient fallback"
)
}
func
(
s
*
HTTPUpstreamSuite
)
TestDo_WithoutProxy_GoesDirect
()
{
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
,
_
=
io
.
WriteString
(
w
,
"direct"
)
}))
s
.
T
()
.
Cleanup
(
upstream
.
Close
)
up
:=
NewHTTPUpstream
(
s
.
cfg
)
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
upstream
.
URL
+
"/x"
,
nil
)
require
.
NoError
(
s
.
T
(),
err
,
"NewRequest"
)
resp
,
err
:=
up
.
Do
(
req
,
""
)
require
.
NoError
(
s
.
T
(),
err
,
"Do"
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
b
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
require
.
Equal
(
s
.
T
(),
"direct"
,
string
(
b
),
"unexpected body"
)
}
func
(
s
*
HTTPUpstreamSuite
)
TestDo_WithHTTPProxy_UsesProxy
()
{
seen
:=
make
(
chan
string
,
1
)
proxySrv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
seen
<-
r
.
RequestURI
_
,
_
=
io
.
WriteString
(
w
,
"proxied"
)
}))
s
.
T
()
.
Cleanup
(
proxySrv
.
Close
)
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ResponseHeaderTimeout
:
1
}
up
:=
NewHTTPUpstream
(
s
.
cfg
)
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
"http://example.com/test"
,
nil
)
require
.
NoError
(
s
.
T
(),
err
,
"NewRequest"
)
resp
,
err
:=
up
.
Do
(
req
,
proxySrv
.
URL
)
require
.
NoError
(
s
.
T
(),
err
,
"Do"
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
b
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
require
.
Equal
(
s
.
T
(),
"proxied"
,
string
(
b
),
"unexpected body"
)
select
{
case
uri
:=
<-
seen
:
require
.
Equal
(
s
.
T
(),
"http://example.com/test"
,
uri
,
"expected absolute-form request URI"
)
default
:
require
.
Fail
(
s
.
T
(),
"expected proxy to receive request"
)
}
}
func
(
s
*
HTTPUpstreamSuite
)
TestDo_EmptyProxy_UsesDirect
()
{
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
,
_
=
io
.
WriteString
(
w
,
"direct-empty"
)
}))
s
.
T
()
.
Cleanup
(
upstream
.
Close
)
up
:=
NewHTTPUpstream
(
s
.
cfg
)
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
upstream
.
URL
+
"/y"
,
nil
)
require
.
NoError
(
s
.
T
(),
err
,
"NewRequest"
)
resp
,
err
:=
up
.
Do
(
req
,
""
)
require
.
NoError
(
s
.
T
(),
err
,
"Do with empty proxy"
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
b
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
require
.
Equal
(
s
.
T
(),
"direct-empty"
,
string
(
b
))
}
func
TestHTTPUpstreamSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
HTTPUpstreamSuite
))
}
backend/internal/repository/identity_cache_integration_test.go
0 → 100644
View file @
25a304c2
//go:build integration
package
repository
import
(
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
IdentityCacheSuite
struct
{
IntegrationRedisSuite
cache
*
identityCache
}
func
(
s
*
IdentityCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewIdentityCache
(
s
.
rdb
)
.
(
*
identityCache
)
}
func
(
s
*
IdentityCacheSuite
)
TestGetFingerprint_Missing
()
{
_
,
err
:=
s
.
cache
.
GetFingerprint
(
s
.
ctx
,
1
)
require
.
True
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected redis.Nil for missing fingerprint"
)
}
func
(
s
*
IdentityCacheSuite
)
TestSetAndGetFingerprint
()
{
fp
:=
&
ports
.
Fingerprint
{
ClientID
:
"c1"
,
UserAgent
:
"ua"
}
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetFingerprint
(
s
.
ctx
,
1
,
fp
),
"SetFingerprint"
)
gotFP
,
err
:=
s
.
cache
.
GetFingerprint
(
s
.
ctx
,
1
)
require
.
NoError
(
s
.
T
(),
err
,
"GetFingerprint"
)
require
.
Equal
(
s
.
T
(),
"c1"
,
gotFP
.
ClientID
)
require
.
Equal
(
s
.
T
(),
"ua"
,
gotFP
.
UserAgent
)
}
func
(
s
*
IdentityCacheSuite
)
TestFingerprint_TTL
()
{
fp
:=
&
ports
.
Fingerprint
{
ClientID
:
"c1"
,
UserAgent
:
"ua"
}
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetFingerprint
(
s
.
ctx
,
2
,
fp
))
fpKey
:=
fmt
.
Sprintf
(
"%s%d"
,
fingerprintKeyPrefix
,
2
)
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
fpKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL fpKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
fingerprintTTL
)
}
func
(
s
*
IdentityCacheSuite
)
TestGetFingerprint_JSONCorruption
()
{
fpKey
:=
fmt
.
Sprintf
(
"%s%d"
,
fingerprintKeyPrefix
,
999
)
require
.
NoError
(
s
.
T
(),
s
.
rdb
.
Set
(
s
.
ctx
,
fpKey
,
"invalid-json-data"
,
1
*
time
.
Minute
)
.
Err
(),
"Set invalid JSON"
)
_
,
err
:=
s
.
cache
.
GetFingerprint
(
s
.
ctx
,
999
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for corrupted JSON"
)
require
.
False
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected decoding error, not redis.Nil"
)
}
func
(
s
*
IdentityCacheSuite
)
TestSetFingerprint_Nil
()
{
err
:=
s
.
cache
.
SetFingerprint
(
s
.
ctx
,
100
,
nil
)
require
.
NoError
(
s
.
T
(),
err
,
"SetFingerprint(nil) should succeed"
)
}
func
TestIdentityCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
IdentityCacheSuite
))
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment