Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
b9b4db3d
Commit
b9b4db3d
authored
Jan 17, 2026
by
song
Browse files
Merge upstream/main
parents
5a6f60a9
dae0d532
Changes
237
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
237 of 237+
files are displayed.
Plain diff
Email patch
backend/internal/repository/proxy_probe_service_test.go
View file @
b9b4db3d
...
...
@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
func
(
s
*
ProxyProbeServiceSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
prober
=
&
proxyProbeService
{
ipInfoURL
:
"http://ip
info
.test/json"
,
ipInfoURL
:
"http://ip
-api
.test/json
/?lang=zh-CN
"
,
allowPrivateHosts
:
true
,
}
}
...
...
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
s
.
setupProxyServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
seen
<-
r
.
RequestURI
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"
ip
":"1.2.3.4","city":"c","region":"r","country":"cc"}`
)
_
,
_
=
io
.
WriteString
(
w
,
`{"
status":"success","query
":"1.2.3.4","city":"c","region
Name
":"r","country":"cc"
,"countryCode":"CC"
}`
)
}))
info
,
latencyMs
,
err
:=
s
.
prober
.
ProbeProxy
(
s
.
ctx
,
s
.
proxySrv
.
URL
)
...
...
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require
.
Equal
(
s
.
T
(),
"c"
,
info
.
City
)
require
.
Equal
(
s
.
T
(),
"r"
,
info
.
Region
)
require
.
Equal
(
s
.
T
(),
"cc"
,
info
.
Country
)
require
.
Equal
(
s
.
T
(),
"CC"
,
info
.
CountryCode
)
// Verify proxy received the request
select
{
case
uri
:=
<-
seen
:
require
.
Contains
(
s
.
T
(),
uri
,
"ip
info
.test"
,
"expected request to go through proxy"
)
require
.
Contains
(
s
.
T
(),
uri
,
"ip
-api
.test"
,
"expected request to go through proxy"
)
default
:
require
.
Fail
(
s
.
T
(),
"expected proxy to receive request"
)
}
...
...
backend/internal/repository/proxy_repo.go
View file @
b9b4db3d
...
...
@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func
(
r
*
proxyRepository
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
var
count
int64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT COUNT(*) FROM accounts WHERE proxy_id = $1"
,
[]
any
{
proxyID
},
&
count
);
err
!=
nil
{
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT COUNT(*) FROM accounts WHERE proxy_id = $1
AND deleted_at IS NULL
"
,
[]
any
{
proxyID
},
&
count
);
err
!=
nil
{
return
0
,
err
}
return
count
,
nil
}
func
(
r
*
proxyRepository
)
ListAccountSummariesByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
([]
service
.
ProxyAccountSummary
,
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT id, name, platform, type, notes
FROM accounts
WHERE proxy_id = $1 AND deleted_at IS NULL
ORDER BY id DESC
`
,
proxyID
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
out
:=
make
([]
service
.
ProxyAccountSummary
,
0
)
for
rows
.
Next
()
{
var
(
id
int64
name
string
platform
string
accType
string
notes
sql
.
NullString
)
if
err
:=
rows
.
Scan
(
&
id
,
&
name
,
&
platform
,
&
accType
,
&
notes
);
err
!=
nil
{
return
nil
,
err
}
var
notesPtr
*
string
if
notes
.
Valid
{
notesPtr
=
&
notes
.
String
}
out
=
append
(
out
,
service
.
ProxyAccountSummary
{
ID
:
id
,
Name
:
name
,
Platform
:
platform
,
Type
:
accType
,
Notes
:
notesPtr
,
})
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func
(
r
*
proxyRepository
)
GetAccountCountsForProxies
(
ctx
context
.
Context
)
(
counts
map
[
int64
]
int64
,
err
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
"SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id"
)
...
...
backend/internal/repository/scheduler_cache.go
0 → 100644
View file @
b9b4db3d
package
repository
import
(
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const
(
schedulerBucketSetKey
=
"sched:buckets"
schedulerOutboxWatermarkKey
=
"sched:outbox:watermark"
schedulerAccountPrefix
=
"sched:acc:"
schedulerActivePrefix
=
"sched:active:"
schedulerReadyPrefix
=
"sched:ready:"
schedulerVersionPrefix
=
"sched:ver:"
schedulerSnapshotPrefix
=
"sched:"
schedulerLockPrefix
=
"sched:lock:"
)
type
schedulerCache
struct
{
rdb
*
redis
.
Client
}
func
NewSchedulerCache
(
rdb
*
redis
.
Client
)
service
.
SchedulerCache
{
return
&
schedulerCache
{
rdb
:
rdb
}
}
func
(
c
*
schedulerCache
)
GetSnapshot
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
)
([]
*
service
.
Account
,
bool
,
error
)
{
readyKey
:=
schedulerBucketKey
(
schedulerReadyPrefix
,
bucket
)
readyVal
,
err
:=
c
.
rdb
.
Get
(
ctx
,
readyKey
)
.
Result
()
if
err
==
redis
.
Nil
{
return
nil
,
false
,
nil
}
if
err
!=
nil
{
return
nil
,
false
,
err
}
if
readyVal
!=
"1"
{
return
nil
,
false
,
nil
}
activeKey
:=
schedulerBucketKey
(
schedulerActivePrefix
,
bucket
)
activeVal
,
err
:=
c
.
rdb
.
Get
(
ctx
,
activeKey
)
.
Result
()
if
err
==
redis
.
Nil
{
return
nil
,
false
,
nil
}
if
err
!=
nil
{
return
nil
,
false
,
err
}
snapshotKey
:=
schedulerSnapshotKey
(
bucket
,
activeVal
)
ids
,
err
:=
c
.
rdb
.
ZRange
(
ctx
,
snapshotKey
,
0
,
-
1
)
.
Result
()
if
err
!=
nil
{
return
nil
,
false
,
err
}
if
len
(
ids
)
==
0
{
return
[]
*
service
.
Account
{},
true
,
nil
}
keys
:=
make
([]
string
,
0
,
len
(
ids
))
for
_
,
id
:=
range
ids
{
keys
=
append
(
keys
,
schedulerAccountKey
(
id
))
}
values
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
nil
,
false
,
err
}
accounts
:=
make
([]
*
service
.
Account
,
0
,
len
(
values
))
for
_
,
val
:=
range
values
{
if
val
==
nil
{
return
nil
,
false
,
nil
}
account
,
err
:=
decodeCachedAccount
(
val
)
if
err
!=
nil
{
return
nil
,
false
,
err
}
accounts
=
append
(
accounts
,
account
)
}
return
accounts
,
true
,
nil
}
func
(
c
*
schedulerCache
)
SetSnapshot
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
,
accounts
[]
service
.
Account
)
error
{
activeKey
:=
schedulerBucketKey
(
schedulerActivePrefix
,
bucket
)
oldActive
,
_
:=
c
.
rdb
.
Get
(
ctx
,
activeKey
)
.
Result
()
versionKey
:=
schedulerBucketKey
(
schedulerVersionPrefix
,
bucket
)
version
,
err
:=
c
.
rdb
.
Incr
(
ctx
,
versionKey
)
.
Result
()
if
err
!=
nil
{
return
err
}
versionStr
:=
strconv
.
FormatInt
(
version
,
10
)
snapshotKey
:=
schedulerSnapshotKey
(
bucket
,
versionStr
)
pipe
:=
c
.
rdb
.
Pipeline
()
for
_
,
account
:=
range
accounts
{
payload
,
err
:=
json
.
Marshal
(
account
)
if
err
!=
nil
{
return
err
}
pipe
.
Set
(
ctx
,
schedulerAccountKey
(
strconv
.
FormatInt
(
account
.
ID
,
10
)),
payload
,
0
)
}
if
len
(
accounts
)
>
0
{
// 使用序号作为 score,保持数据库返回的排序语义。
members
:=
make
([]
redis
.
Z
,
0
,
len
(
accounts
))
for
idx
,
account
:=
range
accounts
{
members
=
append
(
members
,
redis
.
Z
{
Score
:
float64
(
idx
),
Member
:
strconv
.
FormatInt
(
account
.
ID
,
10
),
})
}
pipe
.
ZAdd
(
ctx
,
snapshotKey
,
members
...
)
}
else
{
pipe
.
Del
(
ctx
,
snapshotKey
)
}
pipe
.
Set
(
ctx
,
activeKey
,
versionStr
,
0
)
pipe
.
Set
(
ctx
,
schedulerBucketKey
(
schedulerReadyPrefix
,
bucket
),
"1"
,
0
)
pipe
.
SAdd
(
ctx
,
schedulerBucketSetKey
,
bucket
.
String
())
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
err
}
if
oldActive
!=
""
&&
oldActive
!=
versionStr
{
_
=
c
.
rdb
.
Del
(
ctx
,
schedulerSnapshotKey
(
bucket
,
oldActive
))
.
Err
()
}
return
nil
}
func
(
c
*
schedulerCache
)
GetAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
service
.
Account
,
error
)
{
key
:=
schedulerAccountKey
(
strconv
.
FormatInt
(
accountID
,
10
))
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
if
err
!=
nil
{
return
nil
,
err
}
return
decodeCachedAccount
(
val
)
}
func
(
c
*
schedulerCache
)
SetAccount
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
if
account
==
nil
||
account
.
ID
<=
0
{
return
nil
}
payload
,
err
:=
json
.
Marshal
(
account
)
if
err
!=
nil
{
return
err
}
key
:=
schedulerAccountKey
(
strconv
.
FormatInt
(
account
.
ID
,
10
))
return
c
.
rdb
.
Set
(
ctx
,
key
,
payload
,
0
)
.
Err
()
}
func
(
c
*
schedulerCache
)
DeleteAccount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
if
accountID
<=
0
{
return
nil
}
key
:=
schedulerAccountKey
(
strconv
.
FormatInt
(
accountID
,
10
))
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
func
(
c
*
schedulerCache
)
UpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
if
len
(
updates
)
==
0
{
return
nil
}
keys
:=
make
([]
string
,
0
,
len
(
updates
))
ids
:=
make
([]
int64
,
0
,
len
(
updates
))
for
id
:=
range
updates
{
keys
=
append
(
keys
,
schedulerAccountKey
(
strconv
.
FormatInt
(
id
,
10
)))
ids
=
append
(
ids
,
id
)
}
values
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
err
}
pipe
:=
c
.
rdb
.
Pipeline
()
for
i
,
val
:=
range
values
{
if
val
==
nil
{
continue
}
account
,
err
:=
decodeCachedAccount
(
val
)
if
err
!=
nil
{
return
err
}
account
.
LastUsedAt
=
ptrTime
(
updates
[
ids
[
i
]])
updated
,
err
:=
json
.
Marshal
(
account
)
if
err
!=
nil
{
return
err
}
pipe
.
Set
(
ctx
,
keys
[
i
],
updated
,
0
)
}
_
,
err
=
pipe
.
Exec
(
ctx
)
return
err
}
func
(
c
*
schedulerCache
)
TryLockBucket
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
key
:=
schedulerBucketKey
(
schedulerLockPrefix
,
bucket
)
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
time
.
Now
()
.
UnixNano
(),
ttl
)
.
Result
()
}
func
(
c
*
schedulerCache
)
ListBuckets
(
ctx
context
.
Context
)
([]
service
.
SchedulerBucket
,
error
)
{
raw
,
err
:=
c
.
rdb
.
SMembers
(
ctx
,
schedulerBucketSetKey
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
out
:=
make
([]
service
.
SchedulerBucket
,
0
,
len
(
raw
))
for
_
,
entry
:=
range
raw
{
bucket
,
ok
:=
service
.
ParseSchedulerBucket
(
entry
)
if
!
ok
{
continue
}
out
=
append
(
out
,
bucket
)
}
return
out
,
nil
}
func
(
c
*
schedulerCache
)
GetOutboxWatermark
(
ctx
context
.
Context
)
(
int64
,
error
)
{
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
schedulerOutboxWatermarkKey
)
.
Result
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
err
}
id
,
err
:=
strconv
.
ParseInt
(
val
,
10
,
64
)
if
err
!=
nil
{
return
0
,
err
}
return
id
,
nil
}
func
(
c
*
schedulerCache
)
SetOutboxWatermark
(
ctx
context
.
Context
,
id
int64
)
error
{
return
c
.
rdb
.
Set
(
ctx
,
schedulerOutboxWatermarkKey
,
strconv
.
FormatInt
(
id
,
10
),
0
)
.
Err
()
}
func
schedulerBucketKey
(
prefix
string
,
bucket
service
.
SchedulerBucket
)
string
{
return
fmt
.
Sprintf
(
"%s%d:%s:%s"
,
prefix
,
bucket
.
GroupID
,
bucket
.
Platform
,
bucket
.
Mode
)
}
func
schedulerSnapshotKey
(
bucket
service
.
SchedulerBucket
,
version
string
)
string
{
return
fmt
.
Sprintf
(
"%s%d:%s:%s:v%s"
,
schedulerSnapshotPrefix
,
bucket
.
GroupID
,
bucket
.
Platform
,
bucket
.
Mode
,
version
)
}
func
schedulerAccountKey
(
id
string
)
string
{
return
schedulerAccountPrefix
+
id
}
func
ptrTime
(
t
time
.
Time
)
*
time
.
Time
{
return
&
t
}
func
decodeCachedAccount
(
val
any
)
(
*
service
.
Account
,
error
)
{
var
payload
[]
byte
switch
raw
:=
val
.
(
type
)
{
case
string
:
payload
=
[]
byte
(
raw
)
case
[]
byte
:
payload
=
raw
default
:
return
nil
,
fmt
.
Errorf
(
"unexpected account cache type: %T"
,
val
)
}
var
account
service
.
Account
if
err
:=
json
.
Unmarshal
(
payload
,
&
account
);
err
!=
nil
{
return
nil
,
err
}
return
&
account
,
nil
}
backend/internal/repository/scheduler_outbox_repo.go
0 → 100644
View file @
b9b4db3d
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type
schedulerOutboxRepository
struct
{
db
*
sql
.
DB
}
func
NewSchedulerOutboxRepository
(
db
*
sql
.
DB
)
service
.
SchedulerOutboxRepository
{
return
&
schedulerOutboxRepository
{
db
:
db
}
}
func
(
r
*
schedulerOutboxRepository
)
ListAfter
(
ctx
context
.
Context
,
afterID
int64
,
limit
int
)
([]
service
.
SchedulerOutboxEvent
,
error
)
{
if
limit
<=
0
{
limit
=
100
}
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`
SELECT id, event_type, account_id, group_id, payload, created_at
FROM scheduler_outbox
WHERE id > $1
ORDER BY id ASC
LIMIT $2
`
,
afterID
,
limit
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
events
:=
make
([]
service
.
SchedulerOutboxEvent
,
0
,
limit
)
for
rows
.
Next
()
{
var
(
payloadRaw
[]
byte
accountID
sql
.
NullInt64
groupID
sql
.
NullInt64
event
service
.
SchedulerOutboxEvent
)
if
err
:=
rows
.
Scan
(
&
event
.
ID
,
&
event
.
EventType
,
&
accountID
,
&
groupID
,
&
payloadRaw
,
&
event
.
CreatedAt
);
err
!=
nil
{
return
nil
,
err
}
if
accountID
.
Valid
{
v
:=
accountID
.
Int64
event
.
AccountID
=
&
v
}
if
groupID
.
Valid
{
v
:=
groupID
.
Int64
event
.
GroupID
=
&
v
}
if
len
(
payloadRaw
)
>
0
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
payloadRaw
,
&
payload
);
err
!=
nil
{
return
nil
,
err
}
event
.
Payload
=
payload
}
events
=
append
(
events
,
event
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
events
,
nil
}
func
(
r
*
schedulerOutboxRepository
)
MaxID
(
ctx
context
.
Context
)
(
int64
,
error
)
{
var
maxID
int64
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
"SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox"
)
.
Scan
(
&
maxID
);
err
!=
nil
{
return
0
,
err
}
return
maxID
,
nil
}
func
enqueueSchedulerOutbox
(
ctx
context
.
Context
,
exec
sqlExecutor
,
eventType
string
,
accountID
*
int64
,
groupID
*
int64
,
payload
any
)
error
{
if
exec
==
nil
{
return
nil
}
var
payloadArg
any
if
payload
!=
nil
{
encoded
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
err
}
payloadArg
=
encoded
}
_
,
err
:=
exec
.
ExecContext
(
ctx
,
`
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
VALUES ($1, $2, $3, $4)
`
,
eventType
,
accountID
,
groupID
,
payloadArg
)
return
err
}
backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
0 → 100644
View file @
b9b4db3d
//go:build integration
package
repository
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestSchedulerSnapshotOutboxReplay
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
rdb
:=
testRedis
(
t
)
client
:=
testEntClient
(
t
)
_
,
_
=
integrationDB
.
ExecContext
(
ctx
,
"TRUNCATE scheduler_outbox"
)
accountRepo
:=
newAccountRepositoryWithSQL
(
client
,
integrationDB
)
outboxRepo
:=
NewSchedulerOutboxRepository
(
integrationDB
)
cache
:=
NewSchedulerCache
(
rdb
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
,
Gateway
:
config
.
GatewayConfig
{
Scheduling
:
config
.
GatewaySchedulingConfig
{
OutboxPollIntervalSeconds
:
1
,
FullRebuildIntervalSeconds
:
0
,
DbFallbackEnabled
:
true
,
},
},
}
account
:=
&
service
.
Account
{
Name
:
"outbox-replay-"
+
time
.
Now
()
.
Format
(
"150405.000000"
),
Platform
:
service
.
PlatformOpenAI
,
Type
:
service
.
AccountTypeAPIKey
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Concurrency
:
3
,
Priority
:
1
,
Credentials
:
map
[
string
]
any
{},
Extra
:
map
[
string
]
any
{},
}
require
.
NoError
(
t
,
accountRepo
.
Create
(
ctx
,
account
))
require
.
NoError
(
t
,
cache
.
SetAccount
(
ctx
,
account
))
svc
:=
service
.
NewSchedulerSnapshotService
(
cache
,
outboxRepo
,
accountRepo
,
nil
,
cfg
)
svc
.
Start
()
t
.
Cleanup
(
svc
.
Stop
)
require
.
NoError
(
t
,
accountRepo
.
UpdateLastUsed
(
ctx
,
account
.
ID
))
updated
,
err
:=
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
updated
.
LastUsedAt
)
expectedUnix
:=
updated
.
LastUsedAt
.
Unix
()
require
.
Eventually
(
t
,
func
()
bool
{
cached
,
err
:=
cache
.
GetAccount
(
ctx
,
account
.
ID
)
if
err
!=
nil
||
cached
==
nil
||
cached
.
LastUsedAt
==
nil
{
return
false
}
return
cached
.
LastUsedAt
.
Unix
()
==
expectedUnix
},
5
*
time
.
Second
,
100
*
time
.
Millisecond
)
}
backend/internal/repository/session_limit_cache.go
0 → 100644
View file @
b9b4db3d
package
repository
import
(
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 会话限制缓存常量定义
//
// 设计说明:
// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话:
// - Key: session_limit:account:{accountID}
// - Member: sessionUUID(从 metadata.user_id 中提取)
// - Score: Unix 时间戳(会话最后活跃时间)
//
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
const
(
// 会话限制键前缀
// 格式: session_limit:account:{accountID}
sessionLimitKeyPrefix
=
"session_limit:account:"
// 窗口费用缓存键前缀
// 格式: window_cost:account:{accountID}
windowCostKeyPrefix
=
"window_cost:account:"
// 窗口费用缓存 TTL(30秒)
windowCostCacheTTL
=
30
*
time
.
Second
)
var
(
// registerSessionScript 注册会话活动
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = maxSessions
// ARGV[2] = idleTimeout(秒)
// ARGV[3] = sessionUUID
// 返回: 1 = 允许, 0 = 拒绝
registerSessionScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`
)
// refreshSessionScript 刷新会话时间戳
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
refreshSessionScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`
)
// getActiveSessionCountScript 获取活跃会话数
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
getActiveSessionCountScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`
)
// isSessionActiveScript 检查会话是否活跃
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
isSessionActiveScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`
)
)
type
sessionLimitCache
struct
{
rdb
*
redis
.
Client
defaultIdleTimeout
time
.
Duration
// 默认空闲超时(用于 GetActiveSessionCount)
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func
NewSessionLimitCache
(
rdb
*
redis
.
Client
,
defaultIdleTimeoutMinutes
int
)
service
.
SessionLimitCache
{
if
defaultIdleTimeoutMinutes
<=
0
{
defaultIdleTimeoutMinutes
=
5
// 默认 5 分钟
}
return
&
sessionLimitCache
{
rdb
:
rdb
,
defaultIdleTimeout
:
time
.
Duration
(
defaultIdleTimeoutMinutes
)
*
time
.
Minute
,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func
sessionLimitKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
sessionLimitKeyPrefix
,
accountID
)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func
windowCostKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
windowCostKeyPrefix
,
accountID
)
}
// RegisterSession 注册会话活动
func
(
c
*
sessionLimitCache
)
RegisterSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
maxSessions
int
,
idleTimeout
time
.
Duration
)
(
bool
,
error
)
{
if
sessionUUID
==
""
||
maxSessions
<=
0
{
return
true
,
nil
// 无效参数,默认允许
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
result
,
err
:=
registerSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
maxSessions
,
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
true
,
err
// 失败开放:缓存错误时允许请求通过
}
return
result
==
1
,
nil
}
// RefreshSession 刷新会话时间戳
func
(
c
*
sessionLimitCache
)
RefreshSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
idleTimeout
time
.
Duration
)
error
{
if
sessionUUID
==
""
{
return
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
_
,
err
:=
refreshSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Result
()
return
err
}
// GetActiveSessionCount 获取活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
getActiveSessionCountScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
)
.
Int
()
if
err
!=
nil
{
return
0
,
err
}
return
result
,
nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
int
),
nil
}
results
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
// 使用 pipeline 批量执行
pipe
:=
c
.
rdb
.
Pipeline
()
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
cmds
:=
make
(
map
[
int64
]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
key
:=
sessionLimitKey
(
accountID
)
cmds
[
accountID
]
=
getActiveSessionCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
idleTimeoutSeconds
)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_
,
_
=
pipe
.
Exec
(
ctx
)
for
accountID
,
cmd
:=
range
cmds
{
if
result
,
err
:=
cmd
.
Int
();
err
==
nil
{
results
[
accountID
]
=
result
}
}
return
results
,
nil
}
// IsSessionActive 检查会话是否活跃
func
(
c
*
sessionLimitCache
)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
{
if
sessionUUID
==
""
{
return
false
,
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
isSessionActiveScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
false
,
err
}
return
result
==
1
,
nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func
(
c
*
sessionLimitCache
)
GetWindowCost
(
ctx
context
.
Context
,
accountID
int64
)
(
float64
,
bool
,
error
)
{
key
:=
windowCostKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Float64
()
if
err
==
redis
.
Nil
{
return
0
,
false
,
nil
// 缓存未命中
}
if
err
!=
nil
{
return
0
,
false
,
err
}
return
val
,
true
,
nil
}
// SetWindowCost 设置窗口费用缓存
func
(
c
*
sessionLimitCache
)
SetWindowCost
(
ctx
context
.
Context
,
accountID
int64
,
cost
float64
)
error
{
key
:=
windowCostKey
(
accountID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
cost
,
windowCostCacheTTL
)
.
Err
()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func
(
c
*
sessionLimitCache
)
GetWindowCostBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
float64
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
float64
),
nil
}
// 构建批量查询的 keys
keys
:=
make
([]
string
,
len
(
accountIDs
))
for
i
,
accountID
:=
range
accountIDs
{
keys
[
i
]
=
windowCostKey
(
accountID
)
}
// 使用 MGET 批量获取
vals
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
results
:=
make
(
map
[
int64
]
float64
,
len
(
accountIDs
))
for
i
,
val
:=
range
vals
{
if
val
==
nil
{
continue
// 缓存未命中
}
// 尝试解析为 float64
switch
v
:=
val
.
(
type
)
{
case
string
:
if
cost
,
err
:=
strconv
.
ParseFloat
(
v
,
64
);
err
==
nil
{
results
[
accountIDs
[
i
]]
=
cost
}
case
float64
:
results
[
accountIDs
[
i
]]
=
v
}
}
return
results
,
nil
}
backend/internal/repository/timeout_counter_cache.go
0 → 100644
View file @
b9b4db3d
package
repository
import
(
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const
timeoutCounterPrefix
=
"timeout_count:account:"
// timeoutCounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
// 如果 key 不存在,则创建并设置过期时间
var
timeoutCounterIncrScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`
)
type
timeoutCounterCache
struct
{
rdb
*
redis
.
Client
}
// NewTimeoutCounterCache 创建超时计数器缓存实例
func
NewTimeoutCounterCache
(
rdb
*
redis
.
Client
)
service
.
TimeoutCounterCache
{
return
&
timeoutCounterCache
{
rdb
:
rdb
}
}
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
func
(
c
*
timeoutCounterCache
)
IncrementTimeoutCount
(
ctx
context
.
Context
,
accountID
int64
,
windowMinutes
int
)
(
int64
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
timeoutCounterPrefix
,
accountID
)
ttlSeconds
:=
windowMinutes
*
60
if
ttlSeconds
<
60
{
ttlSeconds
=
60
// 最小1分钟
}
result
,
err
:=
timeoutCounterIncrScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
ttlSeconds
)
.
Int64
()
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"increment timeout count: %w"
,
err
)
}
return
result
,
nil
}
// GetTimeoutCount 获取账户当前的超时计数
func
(
c
*
timeoutCounterCache
)
GetTimeoutCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int64
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
timeoutCounterPrefix
,
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int64
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"get timeout count: %w"
,
err
)
}
return
val
,
nil
}
// ResetTimeoutCount 重置账户的超时计数
func
(
c
*
timeoutCounterCache
)
ResetTimeoutCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
timeoutCounterPrefix
,
accountID
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
// GetTimeoutCountTTL 获取计数器剩余过期时间
func
(
c
*
timeoutCounterCache
)
GetTimeoutCountTTL
(
ctx
context
.
Context
,
accountID
int64
)
(
time
.
Duration
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
timeoutCounterPrefix
,
accountID
)
return
c
.
rdb
.
TTL
(
ctx
,
key
)
.
Result
()
}
backend/internal/repository/usage_log_repo.go
View file @
b9b4db3d
...
...
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier,
account_rate_multiplier,
billing_type, stream, duration_ms, first_token_ms, user_agent,
ip_address,
image_count, image_size, created_at"
type
usageLogRepository
struct
{
client
*
dbent
.
Client
...
...
@@ -105,11 +105,13 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
stream,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
created_at
...
...
@@ -119,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28
$20, $21, $22, $23, $24, $25, $26, $27, $28
, $29, $30
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
...
...
@@ -130,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration
:=
nullInt
(
log
.
DurationMs
)
firstToken
:=
nullInt
(
log
.
FirstTokenMs
)
userAgent
:=
nullString
(
log
.
UserAgent
)
ipAddress
:=
nullString
(
log
.
IPAddress
)
imageSize
:=
nullString
(
log
.
ImageSize
)
var
requestIDArg
any
...
...
@@ -158,11 +161,13 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log
.
TotalCost
,
log
.
ActualCost
,
rateMultiplier
,
log
.
AccountRateMultiplier
,
log
.
BillingType
,
log
.
Stream
,
duration
,
firstToken
,
userAgent
,
ipAddress
,
log
.
ImageCount
,
imageSize
,
createdAt
,
...
...
@@ -266,16 +271,60 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
type
DashboardStats
=
usagestats
.
DashboardStats
func
(
r
*
usageLogRepository
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
DashboardStats
,
error
)
{
var
stats
DashboardStats
today
:=
timezone
.
Today
()
now
:=
time
.
Now
()
stats
:=
&
DashboardStats
{}
now
:=
timezone
.
Now
()
todayStart
:=
timezone
.
Today
()
if
err
:=
r
.
fillDashboardEntityStats
(
ctx
,
stats
,
todayStart
,
now
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
r
.
fillDashboardUsageStatsAggregated
(
ctx
,
stats
,
todayStart
,
now
);
err
!=
nil
{
return
nil
,
err
}
rpm
,
tpm
,
err
:=
r
.
getPerformanceStats
(
ctx
,
0
)
if
err
!=
nil
{
return
nil
,
err
}
stats
.
Rpm
=
rpm
stats
.
Tpm
=
tpm
return
stats
,
nil
}
func
(
r
*
usageLogRepository
)
GetDashboardStatsWithRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
(
*
DashboardStats
,
error
)
{
startUTC
:=
start
.
UTC
()
endUTC
:=
end
.
UTC
()
if
!
endUTC
.
After
(
startUTC
)
{
return
nil
,
errors
.
New
(
"统计时间范围无效"
)
}
// 合并用户统计查询
stats
:=
&
DashboardStats
{}
now
:=
timezone
.
Now
()
todayStart
:=
timezone
.
Today
()
if
err
:=
r
.
fillDashboardEntityStats
(
ctx
,
stats
,
todayStart
,
now
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
r
.
fillDashboardUsageStatsFromUsageLogs
(
ctx
,
stats
,
startUTC
,
endUTC
,
todayStart
,
now
);
err
!=
nil
{
return
nil
,
err
}
rpm
,
tpm
,
err
:=
r
.
getPerformanceStats
(
ctx
,
0
)
if
err
!=
nil
{
return
nil
,
err
}
stats
.
Rpm
=
rpm
stats
.
Tpm
=
tpm
return
stats
,
nil
}
func
(
r
*
usageLogRepository
)
fillDashboardEntityStats
(
ctx
context
.
Context
,
stats
*
DashboardStats
,
todayUTC
,
now
time
.
Time
)
error
{
userStatsQuery
:=
`
SELECT
COUNT(*) as total_users,
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users,
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
FROM users
WHERE deleted_at IS NULL
`
...
...
@@ -283,15 +332,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
ctx
,
r
.
sql
,
userStatsQuery
,
[]
any
{
today
,
today
},
[]
any
{
today
UTC
},
&
stats
.
TotalUsers
,
&
stats
.
TodayNewUsers
,
&
stats
.
ActiveUsers
,
);
err
!=
nil
{
return
nil
,
err
return
err
}
// 合并API Key统计查询
apiKeyStatsQuery
:=
`
SELECT
COUNT(*) as total_api_keys,
...
...
@@ -307,10 +354,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&
stats
.
TotalAPIKeys
,
&
stats
.
ActiveAPIKeys
,
);
err
!=
nil
{
return
nil
,
err
return
err
}
// 合并账户统计查询
accountStatsQuery
:=
`
SELECT
COUNT(*) as total_accounts,
...
...
@@ -332,10 +378,96 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&
stats
.
RateLimitAccounts
,
&
stats
.
OverloadAccounts
,
);
err
!=
nil
{
return
nil
,
err
return
err
}
// 累计 Token 统计
return
nil
}
func
(
r
*
usageLogRepository
)
fillDashboardUsageStatsAggregated
(
ctx
context
.
Context
,
stats
*
DashboardStats
,
todayUTC
,
now
time
.
Time
)
error
{
totalStatsQuery
:=
`
SELECT
COALESCE(SUM(total_requests), 0) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
FROM usage_dashboard_daily
`
var
totalDurationMs
int64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
totalStatsQuery
,
nil
,
&
stats
.
TotalRequests
,
&
stats
.
TotalInputTokens
,
&
stats
.
TotalOutputTokens
,
&
stats
.
TotalCacheCreationTokens
,
&
stats
.
TotalCacheReadTokens
,
&
stats
.
TotalCost
,
&
stats
.
TotalActualCost
,
&
totalDurationMs
,
);
err
!=
nil
{
return
err
}
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheCreationTokens
+
stats
.
TotalCacheReadTokens
if
stats
.
TotalRequests
>
0
{
stats
.
AverageDurationMs
=
float64
(
totalDurationMs
)
/
float64
(
stats
.
TotalRequests
)
}
todayStatsQuery
:=
`
SELECT
total_requests as today_requests,
input_tokens as today_input_tokens,
output_tokens as today_output_tokens,
cache_creation_tokens as today_cache_creation_tokens,
cache_read_tokens as today_cache_read_tokens,
total_cost as today_cost,
actual_cost as today_actual_cost,
active_users as active_users
FROM usage_dashboard_daily
WHERE bucket_date = $1::date
`
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
todayStatsQuery
,
[]
any
{
todayUTC
},
&
stats
.
TodayRequests
,
&
stats
.
TodayInputTokens
,
&
stats
.
TodayOutputTokens
,
&
stats
.
TodayCacheCreationTokens
,
&
stats
.
TodayCacheReadTokens
,
&
stats
.
TodayCost
,
&
stats
.
TodayActualCost
,
&
stats
.
ActiveUsers
,
);
err
!=
nil
{
if
err
!=
sql
.
ErrNoRows
{
return
err
}
}
stats
.
TodayTokens
=
stats
.
TodayInputTokens
+
stats
.
TodayOutputTokens
+
stats
.
TodayCacheCreationTokens
+
stats
.
TodayCacheReadTokens
hourlyActiveQuery
:=
`
SELECT active_users
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`
hourStart
:=
now
.
In
(
timezone
.
Location
())
.
Truncate
(
time
.
Hour
)
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
hourlyActiveQuery
,
[]
any
{
hourStart
},
&
stats
.
HourlyActiveUsers
);
err
!=
nil
{
if
err
!=
sql
.
ErrNoRows
{
return
err
}
}
return
nil
}
func
(
r
*
usageLogRepository
)
fillDashboardUsageStatsFromUsageLogs
(
ctx
context
.
Context
,
stats
*
DashboardStats
,
startUTC
,
endUTC
,
todayUTC
,
now
time
.
Time
)
error
{
totalStatsQuery
:=
`
SELECT
COUNT(*) as total_requests,
...
...
@@ -345,14 +477,16 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(
AVG
(duration_ms), 0) as
avg
_duration_ms
COALESCE(
SUM(COALESCE
(duration_ms
, 0)
), 0) as
total
_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
var
totalDurationMs
int64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
totalStatsQuery
,
nil
,
[]
any
{
startUTC
,
endUTC
}
,
&
stats
.
TotalRequests
,
&
stats
.
TotalInputTokens
,
&
stats
.
TotalOutputTokens
,
...
...
@@ -360,13 +494,16 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&
stats
.
TotalCacheReadTokens
,
&
stats
.
TotalCost
,
&
stats
.
TotalActualCost
,
&
stats
.
Average
DurationMs
,
&
total
DurationMs
,
);
err
!=
nil
{
return
nil
,
err
return
err
}
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheCreationTokens
+
stats
.
TotalCacheReadTokens
if
stats
.
TotalRequests
>
0
{
stats
.
AverageDurationMs
=
float64
(
totalDurationMs
)
/
float64
(
stats
.
TotalRequests
)
}
// 今日 Token 统计
todayEnd
:=
todayUTC
.
Add
(
24
*
time
.
Hour
)
todayStatsQuery
:=
`
SELECT
COUNT(*) as today_requests,
...
...
@@ -377,13 +514,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE created_at >= $1
WHERE created_at >= $1
AND created_at < $2
`
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
todayStatsQuery
,
[]
any
{
today
},
[]
any
{
today
UTC
,
todayEnd
},
&
stats
.
TodayRequests
,
&
stats
.
TodayInputTokens
,
&
stats
.
TodayOutputTokens
,
...
...
@@ -392,19 +529,31 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&
stats
.
TodayCost
,
&
stats
.
TodayActualCost
,
);
err
!=
nil
{
return
nil
,
err
return
err
}
stats
.
TodayTokens
=
stats
.
TodayInputTokens
+
stats
.
TodayOutputTokens
+
stats
.
TodayCacheCreationTokens
+
stats
.
TodayCacheReadTokens
// 性能指标:RPM 和 TPM(最近1分钟,全局)
rpm
,
tpm
,
err
:=
r
.
getPerformanceStats
(
ctx
,
0
)
if
err
!=
nil
{
return
nil
,
err
activeUsersQuery
:=
`
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
activeUsersQuery
,
[]
any
{
todayUTC
,
todayEnd
},
&
stats
.
ActiveUsers
);
err
!=
nil
{
return
err
}
stats
.
Rpm
=
rpm
stats
.
Tpm
=
tpm
return
&
stats
,
nil
hourStart
:=
now
.
UTC
()
.
Truncate
(
time
.
Hour
)
hourEnd
:=
hourStart
.
Add
(
time
.
Hour
)
hourlyActiveQuery
:=
`
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
hourlyActiveQuery
,
[]
any
{
hourStart
,
hourEnd
},
&
stats
.
HourlyActiveUsers
);
err
!=
nil
{
return
err
}
return
nil
}
func
(
r
*
usageLogRepository
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
...
...
@@ -688,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
...
...
@@ -702,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&
stats
.
Requests
,
&
stats
.
Tokens
,
&
stats
.
Cost
,
&
stats
.
StandardCost
,
&
stats
.
UserCost
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -714,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
...
...
@@ -728,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&
stats
.
Requests
,
&
stats
.
Tokens
,
&
stats
.
Cost
,
&
stats
.
StandardCost
,
&
stats
.
UserCost
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1253,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
return
result
,
nil
}
// GetUsageTrendWithFilters returns usage trend data with optional
user/api_key
filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
(
results
[]
TrendDataPoint
,
err
error
)
{
// GetUsageTrendWithFilters returns usage trend data with optional filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
...
...
@@ -1283,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND api_key_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
apiKeyID
)
}
if
accountID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND account_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
accountID
)
}
if
groupID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND group_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
groupID
)
}
if
model
!=
""
{
query
+=
fmt
.
Sprintf
(
" AND model = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
model
)
}
if
stream
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
query
+=
" GROUP BY date ORDER BY date ASC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1305,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return
results
,
nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
(
results
[]
ModelStat
,
err
error
)
{
query
:=
`
// GetModelStatsWithFilters returns model statistics with optional filters
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
(
results
[]
ModelStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
actualCostExpr
=
"COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query
:=
fmt
.
Sprintf
(
`
SELECT
model,
COUNT(*) as requests,
...
...
@@ -1315,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
`
,
actualCostExpr
)
args
:=
[]
any
{
startTime
,
endTime
}
if
userID
>
0
{
...
...
@@ -1333,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND account_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
accountID
)
}
if
groupID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND group_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
groupID
)
}
if
stream
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
query
+=
" GROUP BY model ORDER BY total_tokens DESC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1440,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`
,
buildWhere
(
conditions
))
stats
:=
&
UsageStats
{}
var
totalAccountCost
float64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
...
...
@@ -1457,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&
stats
.
TotalCacheTokens
,
&
stats
.
TotalCost
,
&
stats
.
TotalActualCost
,
&
totalAccountCost
,
&
stats
.
AverageDurationMs
,
);
err
!=
nil
{
return
nil
,
err
}
if
filters
.
AccountID
>
0
{
stats
.
TotalAccountCost
=
&
totalAccountCost
}
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheTokens
return
stats
,
nil
}
...
...
@@ -1487,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
...
...
@@ -1514,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var
tokens
int64
var
cost
float64
var
actualCost
float64
if
err
=
rows
.
Scan
(
&
date
,
&
requests
,
&
tokens
,
&
cost
,
&
actualCost
);
err
!=
nil
{
var
userCost
float64
if
err
=
rows
.
Scan
(
&
date
,
&
requests
,
&
tokens
,
&
cost
,
&
actualCost
,
&
userCost
);
err
!=
nil
{
return
nil
,
err
}
t
,
_
:=
time
.
Parse
(
"2006-01-02"
,
date
)
...
...
@@ -1525,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens
:
tokens
,
Cost
:
cost
,
ActualCost
:
actualCost
,
UserCost
:
userCost
,
})
}
if
err
=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
var
totalAc
tual
Cost
,
totalStandardCost
float64
var
totalAc
countCost
,
totalUser
Cost
,
totalStandardCost
float64
var
totalRequests
,
totalTokens
int64
var
highestCostDay
,
highestRequestDay
*
AccountUsageHistory
for
i
:=
range
history
{
h
:=
&
history
[
i
]
totalActualCost
+=
h
.
ActualCost
totalAccountCost
+=
h
.
ActualCost
totalUserCost
+=
h
.
UserCost
totalStandardCost
+=
h
.
Cost
totalRequests
+=
h
.
Requests
totalTokens
+=
h
.
Tokens
...
...
@@ -1564,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary
:=
AccountUsageSummary
{
Days
:
daysCount
,
ActualDaysUsed
:
actualDaysUsed
,
TotalCost
:
totalActualCost
,
TotalCost
:
totalAccountCost
,
TotalUserCost
:
totalUserCost
,
TotalStandardCost
:
totalStandardCost
,
TotalRequests
:
totalRequests
,
TotalTokens
:
totalTokens
,
AvgDailyCost
:
totalActualCost
/
float64
(
actualDaysUsed
),
AvgDailyCost
:
totalAccountCost
/
float64
(
actualDaysUsed
),
AvgDailyUserCost
:
totalUserCost
/
float64
(
actualDaysUsed
),
AvgDailyRequests
:
float64
(
totalRequests
)
/
float64
(
actualDaysUsed
),
AvgDailyTokens
:
float64
(
totalTokens
)
/
float64
(
actualDaysUsed
),
AvgDurationMs
:
avgDuration
,
...
...
@@ -1580,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary
.
Today
=
&
struct
{
Date
string
`json:"date"`
Cost
float64
`json:"cost"`
UserCost
float64
`json:"user_cost"`
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
}{
Date
:
history
[
i
]
.
Date
,
Cost
:
history
[
i
]
.
ActualCost
,
UserCost
:
history
[
i
]
.
UserCost
,
Requests
:
history
[
i
]
.
Requests
,
Tokens
:
history
[
i
]
.
Tokens
,
}
...
...
@@ -1597,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date
string
`json:"date"`
Label
string
`json:"label"`
Cost
float64
`json:"cost"`
UserCost
float64
`json:"user_cost"`
Requests
int64
`json:"requests"`
}{
Date
:
highestCostDay
.
Date
,
Label
:
highestCostDay
.
Label
,
Cost
:
highestCostDay
.
ActualCost
,
UserCost
:
highestCostDay
.
UserCost
,
Requests
:
highestCostDay
.
Requests
,
}
}
...
...
@@ -1612,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label
string
`json:"label"`
Requests
int64
`json:"requests"`
Cost
float64
`json:"cost"`
UserCost
float64
`json:"user_cost"`
}{
Date
:
highestRequestDay
.
Date
,
Label
:
highestRequestDay
.
Label
,
Requests
:
highestRequestDay
.
Requests
,
Cost
:
highestRequestDay
.
ActualCost
,
UserCost
:
highestRequestDay
.
UserCost
,
}
}
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
)
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
nil
)
if
err
!=
nil
{
models
=
[]
ModelStat
{}
}
...
...
@@ -1868,11 +2073,13 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
totalCost
float64
actualCost
float64
rateMultiplier
float64
accountRateMultiplier
sql
.
NullFloat64
billingType
int16
stream
bool
durationMs
sql
.
NullInt64
firstTokenMs
sql
.
NullInt64
userAgent
sql
.
NullString
ipAddress
sql
.
NullString
imageCount
int
imageSize
sql
.
NullString
createdAt
time
.
Time
...
...
@@ -1900,11 +2107,13 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
totalCost
,
&
actualCost
,
&
rateMultiplier
,
&
accountRateMultiplier
,
&
billingType
,
&
stream
,
&
durationMs
,
&
firstTokenMs
,
&
userAgent
,
&
ipAddress
,
&
imageCount
,
&
imageSize
,
&
createdAt
,
...
...
@@ -1931,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost
:
totalCost
,
ActualCost
:
actualCost
,
RateMultiplier
:
rateMultiplier
,
AccountRateMultiplier
:
nullFloat64Ptr
(
accountRateMultiplier
),
BillingType
:
int8
(
billingType
),
Stream
:
stream
,
ImageCount
:
imageCount
,
...
...
@@ -1959,6 +2169,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if
userAgent
.
Valid
{
log
.
UserAgent
=
&
userAgent
.
String
}
if
ipAddress
.
Valid
{
log
.
IPAddress
=
&
ipAddress
.
String
}
if
imageSize
.
Valid
{
log
.
ImageSize
=
&
imageSize
.
String
}
...
...
@@ -2034,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
return
sql
.
NullInt64
{
Int64
:
int64
(
*
v
),
Valid
:
true
}
}
func
nullFloat64Ptr
(
v
sql
.
NullFloat64
)
*
float64
{
if
!
v
.
Valid
{
return
nil
}
out
:=
v
.
Float64
return
&
out
}
func
nullString
(
v
*
string
)
sql
.
NullString
{
if
v
==
nil
||
*
v
==
""
{
return
sql
.
NullString
{}
...
...
backend/internal/repository/usage_log_repo_integration_test.go
View file @
b9b4db3d
...
...
@@ -37,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite
.
Run
(
t
,
new
(
UsageLogRepoSuite
))
}
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
func
truncateToDayUTC
(
t
time
.
Time
)
time
.
Time
{
t
=
t
.
UTC
()
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
time
.
UTC
)
}
func
(
s
*
UsageLogRepoSuite
)
createUsageLog
(
user
*
service
.
User
,
apiKey
*
service
.
APIKey
,
account
*
service
.
Account
,
inputTokens
,
outputTokens
int
,
cost
float64
,
createdAt
time
.
Time
)
*
service
.
UsageLog
{
log
:=
&
service
.
UsageLog
{
UserID
:
user
.
ID
,
...
...
@@ -96,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s
.
Require
()
.
Error
(
err
,
"expected error for non-existent ID"
)
}
func
(
s
*
UsageLogRepoSuite
)
TestGetByID_ReturnsAccountRateMultiplier
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
client
,
&
service
.
User
{
Email
:
"getbyid-mult@test.com"
})
apiKey
:=
mustCreateApiKey
(
s
.
T
(),
s
.
client
,
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-getbyid-mult"
,
Name
:
"k"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-getbyid-mult"
})
m
:=
0.5
log
:=
&
service
.
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
uuid
.
New
()
.
String
(),
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
TotalCost
:
1.0
,
ActualCost
:
2.0
,
AccountRateMultiplier
:
&
m
,
CreatedAt
:
timezone
.
Today
()
.
Add
(
2
*
time
.
Hour
),
}
_
,
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
log
)
s
.
Require
()
.
NoError
(
err
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
log
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NotNil
(
got
.
AccountRateMultiplier
)
s
.
Require
()
.
InEpsilon
(
0.5
,
*
got
.
AccountRateMultiplier
,
0.0001
)
}
// --- Delete ---
func
(
s
*
UsageLogRepoSuite
)
TestDelete
()
{
...
...
@@ -198,14 +232,14 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
// --- GetDashboardStats ---
func
(
s
*
UsageLogRepoSuite
)
TestDashboardStats_TodayTotalsAndPerformance
()
{
now
:=
time
.
Now
()
todayStart
:=
t
imezone
.
To
d
ay
(
)
now
:=
time
.
Now
()
.
UTC
()
todayStart
:=
t
runcate
To
D
ay
UTC
(
now
)
baseStats
,
err
:=
s
.
repo
.
GetDashboardStats
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"GetDashboardStats base"
)
userToday
:=
mustCreateUser
(
s
.
T
(),
s
.
client
,
&
service
.
User
{
Email
:
"today@example.com"
,
CreatedAt
:
m
axTime
(
todayStart
.
Add
(
10
*
time
.
Second
),
now
.
Add
(
-
10
*
time
.
Second
)),
CreatedAt
:
testM
axTime
(
todayStart
.
Add
(
10
*
time
.
Second
),
now
.
Add
(
-
10
*
time
.
Second
)),
UpdatedAt
:
now
,
})
userOld
:=
mustCreateUser
(
s
.
T
(),
s
.
client
,
&
service
.
User
{
...
...
@@ -238,7 +272,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
TotalCost
:
1.5
,
ActualCost
:
1.2
,
DurationMs
:
&
d1
,
CreatedAt
:
m
axTime
(
todayStart
.
Add
(
2
*
time
.
Minute
),
now
.
Add
(
-
2
*
time
.
Minute
)),
CreatedAt
:
testM
axTime
(
todayStart
.
Add
(
2
*
time
.
Minute
),
now
.
Add
(
-
2
*
time
.
Minute
)),
}
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
logToday
)
s
.
Require
()
.
NoError
(
err
,
"Create logToday"
)
...
...
@@ -273,6 +307,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
logPerf
)
s
.
Require
()
.
NoError
(
err
,
"Create logPerf"
)
aggRepo
:=
newDashboardAggregationRepositoryWithSQL
(
s
.
tx
)
aggStart
:=
todayStart
.
Add
(
-
2
*
time
.
Hour
)
aggEnd
:=
now
.
Add
(
2
*
time
.
Minute
)
s
.
Require
()
.
NoError
(
aggRepo
.
AggregateRange
(
s
.
ctx
,
aggStart
,
aggEnd
),
"AggregateRange"
)
stats
,
err
:=
s
.
repo
.
GetDashboardStats
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"GetDashboardStats"
)
...
...
@@ -303,6 +342,80 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s
.
Require
()
.
Equal
(
wantTpm
,
stats
.
Tpm
,
"Tpm mismatch"
)
}
func
(
s
*
UsageLogRepoSuite
)
TestDashboardStatsWithRange_Fallback
()
{
now
:=
time
.
Now
()
.
UTC
()
todayStart
:=
truncateToDayUTC
(
now
)
rangeStart
:=
todayStart
.
Add
(
-
24
*
time
.
Hour
)
rangeEnd
:=
now
.
Add
(
1
*
time
.
Second
)
user1
:=
mustCreateUser
(
s
.
T
(),
s
.
client
,
&
service
.
User
{
Email
:
"range-u1@test.com"
})
user2
:=
mustCreateUser
(
s
.
T
(),
s
.
client
,
&
service
.
User
{
Email
:
"range-u2@test.com"
})
apiKey1
:=
mustCreateApiKey
(
s
.
T
(),
s
.
client
,
&
service
.
APIKey
{
UserID
:
user1
.
ID
,
Key
:
"sk-range-1"
,
Name
:
"k1"
})
apiKey2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
client
,
&
service
.
APIKey
{
UserID
:
user2
.
ID
,
Key
:
"sk-range-2"
,
Name
:
"k2"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-range"
})
d1
,
d2
,
d3
:=
100
,
200
,
300
logOutside
:=
&
service
.
UsageLog
{
UserID
:
user1
.
ID
,
APIKeyID
:
apiKey1
.
ID
,
AccountID
:
account
.
ID
,
Model
:
"claude-3"
,
InputTokens
:
7
,
OutputTokens
:
8
,
TotalCost
:
0.8
,
ActualCost
:
0.7
,
DurationMs
:
&
d3
,
CreatedAt
:
rangeStart
.
Add
(
-
1
*
time
.
Hour
),
}
_
,
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
logOutside
)
s
.
Require
()
.
NoError
(
err
)
logRange
:=
&
service
.
UsageLog
{
UserID
:
user1
.
ID
,
APIKeyID
:
apiKey1
.
ID
,
AccountID
:
account
.
ID
,
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
CacheCreationTokens
:
1
,
CacheReadTokens
:
2
,
TotalCost
:
1.0
,
ActualCost
:
0.9
,
DurationMs
:
&
d1
,
CreatedAt
:
rangeStart
.
Add
(
2
*
time
.
Hour
),
}
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
logRange
)
s
.
Require
()
.
NoError
(
err
)
logToday
:=
&
service
.
UsageLog
{
UserID
:
user2
.
ID
,
APIKeyID
:
apiKey2
.
ID
,
AccountID
:
account
.
ID
,
Model
:
"claude-3"
,
InputTokens
:
5
,
OutputTokens
:
6
,
CacheReadTokens
:
1
,
TotalCost
:
0.5
,
ActualCost
:
0.5
,
DurationMs
:
&
d2
,
CreatedAt
:
now
,
}
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
logToday
)
s
.
Require
()
.
NoError
(
err
)
stats
,
err
:=
s
.
repo
.
GetDashboardStatsWithRange
(
s
.
ctx
,
rangeStart
,
rangeEnd
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
int64
(
2
),
stats
.
TotalRequests
)
s
.
Require
()
.
Equal
(
int64
(
15
),
stats
.
TotalInputTokens
)
s
.
Require
()
.
Equal
(
int64
(
26
),
stats
.
TotalOutputTokens
)
s
.
Require
()
.
Equal
(
int64
(
1
),
stats
.
TotalCacheCreationTokens
)
s
.
Require
()
.
Equal
(
int64
(
3
),
stats
.
TotalCacheReadTokens
)
s
.
Require
()
.
Equal
(
int64
(
45
),
stats
.
TotalTokens
)
s
.
Require
()
.
Equal
(
1.5
,
stats
.
TotalCost
)
s
.
Require
()
.
Equal
(
1.4
,
stats
.
TotalActualCost
)
s
.
Require
()
.
InEpsilon
(
150.0
,
stats
.
AverageDurationMs
,
0.0001
)
}
// --- GetUserDashboardStats ---
func
(
s
*
UsageLogRepoSuite
)
TestGetUserDashboardStats
()
{
...
...
@@ -325,12 +438,202 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey
:=
mustCreateApiKey
(
s
.
T
(),
s
.
client
,
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-acctoday"
,
Name
:
"k"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-today"
})
s
.
createUsageLog
(
user
,
apiKey
,
account
,
10
,
20
,
0.5
,
time
.
Now
())
createdAt
:=
timezone
.
Today
()
.
Add
(
1
*
time
.
Hour
)
m1
:=
1.5
m2
:=
0.0
_
,
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
&
service
.
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
uuid
.
New
()
.
String
(),
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
TotalCost
:
1.0
,
ActualCost
:
2.0
,
AccountRateMultiplier
:
&
m1
,
CreatedAt
:
createdAt
,
})
s
.
Require
()
.
NoError
(
err
)
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
&
service
.
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
uuid
.
New
()
.
String
(),
Model
:
"claude-3"
,
InputTokens
:
5
,
OutputTokens
:
5
,
TotalCost
:
0.5
,
ActualCost
:
1.0
,
AccountRateMultiplier
:
&
m2
,
CreatedAt
:
createdAt
,
})
s
.
Require
()
.
NoError
(
err
)
stats
,
err
:=
s
.
repo
.
GetAccountTodayStats
(
s
.
ctx
,
account
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetAccountTodayStats"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
stats
.
Requests
)
s
.
Require
()
.
Equal
(
int64
(
30
),
stats
.
Tokens
)
s
.
Require
()
.
Equal
(
int64
(
2
),
stats
.
Requests
)
s
.
Require
()
.
Equal
(
int64
(
40
),
stats
.
Tokens
)
// account cost = SUM(total_cost * account_rate_multiplier)
s
.
Require
()
.
InEpsilon
(
1.5
,
stats
.
Cost
,
0.0001
)
// standard cost = SUM(total_cost)
s
.
Require
()
.
InEpsilon
(
1.5
,
stats
.
StandardCost
,
0.0001
)
// user cost = SUM(actual_cost)
s
.
Require
()
.
InEpsilon
(
3.0
,
stats
.
UserCost
,
0.0001
)
}
func
(
s
*
UsageLogRepoSuite
)
TestDashboardAggregationConsistency
()
{
now
:=
time
.
Now
()
.
UTC
()
.
Truncate
(
time
.
Second
)
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
dayStart
:=
truncateToDayUTC
(
now
)
hour1
:=
dayStart
.
Add
(
2
*
time
.
Hour
)
// 当天 02:00
hour2
:=
dayStart
.
Add
(
3
*
time
.
Hour
)
// 当天 03:00
// 如果当前时间早于 hour2,则使用昨天的时间
if
now
.
Before
(
hour2
.
Add
(
time
.
Hour
))
{
dayStart
=
dayStart
.
Add
(
-
24
*
time
.
Hour
)
hour1
=
dayStart
.
Add
(
2
*
time
.
Hour
)
hour2
=
dayStart
.
Add
(
3
*
time
.
Hour
)
}
user1
:=
mustCreateUser
(
s
.
T
(),
s
.
client
,
&
service
.
User
{
Email
:
"agg-u1@test.com"
})
user2
:=
mustCreateUser
(
s
.
T
(),
s
.
client
,
&
service
.
User
{
Email
:
"agg-u2@test.com"
})
apiKey1
:=
mustCreateApiKey
(
s
.
T
(),
s
.
client
,
&
service
.
APIKey
{
UserID
:
user1
.
ID
,
Key
:
"sk-agg-1"
,
Name
:
"k1"
})
apiKey2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
client
,
&
service
.
APIKey
{
UserID
:
user2
.
ID
,
Key
:
"sk-agg-2"
,
Name
:
"k2"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-agg"
})
d1
,
d2
,
d3
:=
100
,
200
,
150
log1
:=
&
service
.
UsageLog
{
UserID
:
user1
.
ID
,
APIKeyID
:
apiKey1
.
ID
,
AccountID
:
account
.
ID
,
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
CacheCreationTokens
:
2
,
CacheReadTokens
:
1
,
TotalCost
:
1.0
,
ActualCost
:
0.9
,
DurationMs
:
&
d1
,
CreatedAt
:
hour1
.
Add
(
5
*
time
.
Minute
),
}
_
,
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
log1
)
s
.
Require
()
.
NoError
(
err
)
log2
:=
&
service
.
UsageLog
{
UserID
:
user1
.
ID
,
APIKeyID
:
apiKey1
.
ID
,
AccountID
:
account
.
ID
,
Model
:
"claude-3"
,
InputTokens
:
5
,
OutputTokens
:
5
,
TotalCost
:
0.5
,
ActualCost
:
0.5
,
DurationMs
:
&
d2
,
CreatedAt
:
hour1
.
Add
(
20
*
time
.
Minute
),
}
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
log2
)
s
.
Require
()
.
NoError
(
err
)
log3
:=
&
service
.
UsageLog
{
UserID
:
user2
.
ID
,
APIKeyID
:
apiKey2
.
ID
,
AccountID
:
account
.
ID
,
Model
:
"claude-3"
,
InputTokens
:
7
,
OutputTokens
:
8
,
TotalCost
:
0.7
,
ActualCost
:
0.7
,
DurationMs
:
&
d3
,
CreatedAt
:
hour2
.
Add
(
10
*
time
.
Minute
),
}
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
log3
)
s
.
Require
()
.
NoError
(
err
)
aggRepo
:=
newDashboardAggregationRepositoryWithSQL
(
s
.
tx
)
aggStart
:=
hour1
.
Add
(
-
5
*
time
.
Minute
)
aggEnd
:=
hour2
.
Add
(
time
.
Hour
)
// 确保覆盖 hour2 的所有数据
s
.
Require
()
.
NoError
(
aggRepo
.
AggregateRange
(
s
.
ctx
,
aggStart
,
aggEnd
))
type
hourlyRow
struct
{
totalRequests
int64
inputTokens
int64
outputTokens
int64
cacheCreationTokens
int64
cacheReadTokens
int64
totalCost
float64
actualCost
float64
totalDurationMs
int64
activeUsers
int64
}
fetchHourly
:=
func
(
bucketStart
time
.
Time
)
hourlyRow
{
var
row
hourlyRow
err
:=
scanSingleRow
(
s
.
ctx
,
s
.
tx
,
`
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
total_cost, actual_cost, total_duration_ms, active_users
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`
,
[]
any
{
bucketStart
},
&
row
.
totalRequests
,
&
row
.
inputTokens
,
&
row
.
outputTokens
,
&
row
.
cacheCreationTokens
,
&
row
.
cacheReadTokens
,
&
row
.
totalCost
,
&
row
.
actualCost
,
&
row
.
totalDurationMs
,
&
row
.
activeUsers
,
)
s
.
Require
()
.
NoError
(
err
)
return
row
}
hour1Row
:=
fetchHourly
(
hour1
)
s
.
Require
()
.
Equal
(
int64
(
2
),
hour1Row
.
totalRequests
)
s
.
Require
()
.
Equal
(
int64
(
15
),
hour1Row
.
inputTokens
)
s
.
Require
()
.
Equal
(
int64
(
25
),
hour1Row
.
outputTokens
)
s
.
Require
()
.
Equal
(
int64
(
2
),
hour1Row
.
cacheCreationTokens
)
s
.
Require
()
.
Equal
(
int64
(
1
),
hour1Row
.
cacheReadTokens
)
s
.
Require
()
.
Equal
(
1.5
,
hour1Row
.
totalCost
)
s
.
Require
()
.
Equal
(
1.4
,
hour1Row
.
actualCost
)
s
.
Require
()
.
Equal
(
int64
(
300
),
hour1Row
.
totalDurationMs
)
s
.
Require
()
.
Equal
(
int64
(
1
),
hour1Row
.
activeUsers
)
hour2Row
:=
fetchHourly
(
hour2
)
s
.
Require
()
.
Equal
(
int64
(
1
),
hour2Row
.
totalRequests
)
s
.
Require
()
.
Equal
(
int64
(
7
),
hour2Row
.
inputTokens
)
s
.
Require
()
.
Equal
(
int64
(
8
),
hour2Row
.
outputTokens
)
s
.
Require
()
.
Equal
(
int64
(
0
),
hour2Row
.
cacheCreationTokens
)
s
.
Require
()
.
Equal
(
int64
(
0
),
hour2Row
.
cacheReadTokens
)
s
.
Require
()
.
Equal
(
0.7
,
hour2Row
.
totalCost
)
s
.
Require
()
.
Equal
(
0.7
,
hour2Row
.
actualCost
)
s
.
Require
()
.
Equal
(
int64
(
150
),
hour2Row
.
totalDurationMs
)
s
.
Require
()
.
Equal
(
int64
(
1
),
hour2Row
.
activeUsers
)
var
daily
struct
{
totalRequests
int64
inputTokens
int64
outputTokens
int64
cacheCreationTokens
int64
cacheReadTokens
int64
totalCost
float64
actualCost
float64
totalDurationMs
int64
activeUsers
int64
}
err
=
scanSingleRow
(
s
.
ctx
,
s
.
tx
,
`
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
total_cost, actual_cost, total_duration_ms, active_users
FROM usage_dashboard_daily
WHERE bucket_date = $1::date
`
,
[]
any
{
dayStart
},
&
daily
.
totalRequests
,
&
daily
.
inputTokens
,
&
daily
.
outputTokens
,
&
daily
.
cacheCreationTokens
,
&
daily
.
cacheReadTokens
,
&
daily
.
totalCost
,
&
daily
.
actualCost
,
&
daily
.
totalDurationMs
,
&
daily
.
activeUsers
,
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
int64
(
3
),
daily
.
totalRequests
)
s
.
Require
()
.
Equal
(
int64
(
22
),
daily
.
inputTokens
)
s
.
Require
()
.
Equal
(
int64
(
33
),
daily
.
outputTokens
)
s
.
Require
()
.
Equal
(
int64
(
2
),
daily
.
cacheCreationTokens
)
s
.
Require
()
.
Equal
(
int64
(
1
),
daily
.
cacheReadTokens
)
s
.
Require
()
.
Equal
(
2.2
,
daily
.
totalCost
)
s
.
Require
()
.
Equal
(
2.1
,
daily
.
actualCost
)
s
.
Require
()
.
Equal
(
int64
(
450
),
daily
.
totalDurationMs
)
s
.
Require
()
.
Equal
(
int64
(
2
),
daily
.
activeUsers
)
}
// --- GetBatchUserUsageStats ---
...
...
@@ -398,7 +701,7 @@ func (s *UsageLogRepoSuite) TestGetGlobalStats() {
s
.
Require
()
.
Equal
(
int64
(
45
),
stats
.
TotalOutputTokens
)
}
func
m
axTime
(
a
,
b
time
.
Time
)
time
.
Time
{
func
testM
axTime
(
a
,
b
time
.
Time
)
time
.
Time
{
if
a
.
After
(
b
)
{
return
a
}
...
...
@@ -641,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime
:=
base
.
Add
(
48
*
time
.
Hour
)
// Test with user filter
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters user filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with apiKey filter
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with both filters
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters both filters"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -668,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime
:=
base
.
Add
(
-
1
*
time
.
Hour
)
endTime
:=
base
.
Add
(
3
*
time
.
Hour
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters hourly"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -714,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime
:=
base
.
Add
(
2
*
time
.
Hour
)
// Test with user filter
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
)
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters user filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with apiKey filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with account filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters account filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
}
...
...
backend/internal/repository/wire.go
View file @
b9b4db3d
...
...
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return
NewPricingRemoteClient
(
cfg
.
Update
.
ProxyURL
)
}
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func
ProvideSessionLimitCache
(
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
)
service
.
SessionLimitCache
{
defaultIdleTimeoutMinutes
:=
5
// 默认 5 分钟空闲超时
if
cfg
!=
nil
&&
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
>
0
{
defaultIdleTimeoutMinutes
=
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
}
return
NewSessionLimitCache
(
rdb
,
defaultIdleTimeoutMinutes
)
}
// ProviderSet is the Wire provider set for all repositories
var
ProviderSet
=
wire
.
NewSet
(
NewUserRepository
,
...
...
@@ -45,8 +55,11 @@ var ProviderSet = wire.NewSet(
NewAccountRepository
,
NewProxyRepository
,
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewUsageLogRepository
,
NewDashboardAggregationRepository
,
NewSettingRepository
,
NewOpsRepository
,
NewUserSubscriptionRepository
,
NewUserAttributeDefinitionRepository
,
NewUserAttributeValueRepository
,
...
...
@@ -56,12 +69,18 @@ var ProviderSet = wire.NewSet(
NewBillingCache
,
NewAPIKeyCache
,
NewTempUnschedCache
,
NewTimeoutCounterCache
,
ProvideConcurrencyCache
,
ProvideSessionLimitCache
,
NewDashboardCache
,
NewEmailCache
,
NewIdentityCache
,
NewRedeemCache
,
NewUpdateCache
,
NewGeminiTokenCache
,
NewSchedulerCache
,
NewSchedulerOutboxRepository
,
NewProxyLatencyCache
,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
...
...
backend/internal/server/api_contract_test.go
View file @
b9b4db3d
...
...
@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -237,6 +241,7 @@ func TestAPIContracts(t *testing.T) {
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
...
...
@@ -283,6 +288,11 @@ func TestAPIContracts(t *testing.T) {
service
.
SettingKeyDefaultConcurrency
:
"5"
,
service
.
SettingKeyDefaultBalance
:
"1.25"
,
service
.
SettingKeyOpsMonitoringEnabled
:
"false"
,
service
.
SettingKeyOpsRealtimeMonitoringEnabled
:
"true"
,
service
.
SettingKeyOpsQueryModeDefault
:
"auto"
,
service
.
SettingKeyOpsMetricsIntervalSeconds
:
"60"
,
})
},
method
:
http
.
MethodGet
,
...
...
@@ -308,6 +318,10 @@ func TestAPIContracts(t *testing.T) {
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
"ops_metrics_interval_seconds": 60,
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
...
...
@@ -322,7 +336,32 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": ""
"identity_patch_prompt": "",
"home_content": ""
}
}`
,
},
{
name
:
"POST /api/v1/admin/accounts/bulk-update"
,
method
:
http
.
MethodPost
,
path
:
"/api/v1/admin/accounts/bulk-update"
,
body
:
`{"account_ids":[101,102],"schedulable":false}`
,
headers
:
map
[
string
]
string
{
"Content-Type"
:
"application/json"
,
},
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": {
"success": 2,
"failed": 0,
"success_ids": [101, 102],
"failed_ids": [],
"results": [
{"account_id": 101, "success": true},
{"account_id": 102, "success": true}
]
}
}`
,
},
...
...
@@ -377,6 +416,9 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyCache
:=
stubApiKeyCache
{}
groupRepo
:=
stubGroupRepo
{}
userSubRepo
:=
stubUserSubscriptionRepo
{}
accountRepo
:=
stubAccountRepo
{}
proxyRepo
:=
stubProxyRepo
{}
redeemRepo
:=
stubRedeemCodeRepo
{}
cfg
:=
&
config
.
Config
{
Default
:
config
.
DefaultConfig
{
...
...
@@ -385,19 +427,21 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode
:
config
.
RunModeStandard
,
}
userService
:=
service
.
NewUserService
(
userRepo
)
userService
:=
service
.
NewUserService
(
userRepo
,
nil
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
userRepo
,
groupRepo
,
userSubRepo
,
apiKeyCache
,
cfg
)
usageRepo
:=
newStubUsageLogRepo
()
usageService
:=
service
.
NewUsageService
(
usageRepo
,
userRepo
,
nil
)
usageService
:=
service
.
NewUsageService
(
usageRepo
,
userRepo
,
nil
,
nil
)
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
@@ -437,6 +481,7 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Admin
:=
v1
.
Group
(
"/admin"
)
v1Admin
.
Use
(
adminAuth
)
v1Admin
.
GET
(
"/settings"
,
adminSettingHandler
.
GetSettings
)
v1Admin
.
POST
(
"/accounts/bulk-update"
,
adminAccountHandler
.
BulkUpdate
)
return
&
contractDeps
{
now
:
now
,
...
...
@@ -561,6 +606,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t
return
nil
}
func
(
stubApiKeyCache
)
GetAuthCache
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKeyAuthCacheEntry
,
error
)
{
return
nil
,
nil
}
func
(
stubApiKeyCache
)
SetAuthCache
(
ctx
context
.
Context
,
key
string
,
entry
*
service
.
APIKeyAuthCacheEntry
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
stubApiKeyCache
)
DeleteAuthCache
(
ctx
context
.
Context
,
key
string
)
error
{
return
nil
}
type
stubGroupRepo
struct
{}
func
(
stubGroupRepo
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
...
...
@@ -571,6 +628,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -611,6 +672,251 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubAccountRepo
struct
{
bulkUpdateIDs
[]
int64
}
func
(
s
*
stubAccountRepo
)
Create
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
return
nil
,
service
.
ErrAccountNotFound
}
func
(
s
*
stubAccountRepo
)
GetByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ExistsByID
(
ctx
context
.
Context
,
id
int64
)
(
bool
,
error
)
{
return
false
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
AutoPauseExpiredAccounts
(
ctx
context
.
Context
,
now
time
.
Time
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListSchedulable
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
service
.
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ClearTempUnschedulable
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
s
.
bulkUpdateIDs
=
append
([]
int64
{},
ids
...
)
return
int64
(
len
(
ids
)),
nil
}
type
stubProxyRepo
struct
{}
func
(
stubProxyRepo
)
Create
(
ctx
context
.
Context
,
proxy
*
service
.
Proxy
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Proxy
,
error
)
{
return
nil
,
service
.
ErrProxyNotFound
}
func
(
stubProxyRepo
)
Update
(
ctx
context
.
Context
,
proxy
*
service
.
Proxy
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
service
.
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
ListWithFiltersAndAccountCount
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
service
.
ProxyWithAccountCount
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Proxy
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
ListActiveWithAccountCount
(
ctx
context
.
Context
)
([]
service
.
ProxyWithAccountCount
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
ExistsByHostPortAuth
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
{
return
false
,
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
stubProxyRepo
)
ListAccountSummariesByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
([]
service
.
ProxyAccountSummary
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubRedeemCodeRepo
struct
{}
func
(
stubRedeemCodeRepo
)
Create
(
ctx
context
.
Context
,
code
*
service
.
RedeemCode
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
CreateBatch
(
ctx
context
.
Context
,
codes
[]
service
.
RedeemCode
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
RedeemCode
,
error
)
{
return
nil
,
service
.
ErrRedeemCodeNotFound
}
func
(
stubRedeemCodeRepo
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
service
.
RedeemCode
,
error
)
{
return
nil
,
service
.
ErrRedeemCodeNotFound
}
func
(
stubRedeemCodeRepo
)
Update
(
ctx
context
.
Context
,
code
*
service
.
RedeemCode
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
Use
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
codeType
,
status
,
search
string
)
([]
service
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubRedeemCodeRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
service
.
RedeemCode
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubUserSubscriptionRepo
struct
{}
func
(
stubUserSubscriptionRepo
)
Create
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
...
...
@@ -729,12 +1035,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return
&
clone
,
nil
}
func
(
r
*
stubApiKeyRepo
)
GetOwnerID
(
ctx
context
.
Context
,
id
int64
)
(
int64
,
error
)
{
func
(
r
*
stubApiKeyRepo
)
Get
KeyAnd
OwnerID
(
ctx
context
.
Context
,
id
int64
)
(
string
,
int64
,
error
)
{
key
,
ok
:=
r
.
byID
[
id
]
if
!
ok
{
return
0
,
service
.
ErrAPIKeyNotFound
return
""
,
0
,
service
.
ErrAPIKeyNotFound
}
return
key
.
UserID
,
nil
return
key
.
Key
,
key
.
UserID
,
nil
}
func
(
r
*
stubApiKeyRepo
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
...
...
@@ -746,6 +1052,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return
&
clone
,
nil
}
func
(
r
*
stubApiKeyRepo
)
GetByKeyForAuth
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
return
r
.
GetByKey
(
ctx
,
key
)
}
func
(
r
*
stubApiKeyRepo
)
Update
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
if
key
==
nil
{
return
errors
.
New
(
"nil key"
)
...
...
@@ -860,6 +1170,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubUsageLogRepo
struct
{
userLogs
map
[
int64
][]
service
.
UsageLog
}
...
...
@@ -928,11 +1246,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/http.go
View file @
b9b4db3d
...
...
@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProviderSet 提供服务器层的依赖
...
...
@@ -30,6 +31,9 @@ func ProvideRouter(
apiKeyAuth
middleware2
.
APIKeyAuthMiddleware
,
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
opsService
*
service
.
OpsService
,
settingService
*
service
.
SettingService
,
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
if
cfg
.
Server
.
Mode
==
"release"
{
gin
.
SetMode
(
gin
.
ReleaseMode
)
...
...
@@ -47,7 +51,7 @@ func ProvideRouter(
}
}
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
,
redisClient
)
}
// ProvideHTTPServer 提供 HTTP 服务器
...
...
backend/internal/server/middleware/admin_auth.go
View file @
b9b4db3d
...
...
@@ -30,6 +30,20 @@ func adminAuth(
settingService
*
service
.
SettingService
,
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
// WebSocket upgrade requests cannot set Authorization headers in browsers.
// For admin WebSocket endpoints (e.g. Ops realtime), allow passing the JWT via
// Sec-WebSocket-Protocol (subprotocol list) using a prefixed token item:
// Sec-WebSocket-Protocol: sub2api-admin, jwt.<token>
if
isWebSocketUpgradeRequest
(
c
)
{
if
token
:=
extractJWTFromWebSocketSubprotocol
(
c
);
token
!=
""
{
if
!
validateJWTForAdmin
(
c
,
token
,
authService
,
userService
)
{
return
}
c
.
Next
()
return
}
}
// 检查 x-api-key header(Admin API Key 认证)
apiKey
:=
c
.
GetHeader
(
"x-api-key"
)
if
apiKey
!=
""
{
...
...
@@ -58,6 +72,44 @@ func adminAuth(
}
}
func
isWebSocketUpgradeRequest
(
c
*
gin
.
Context
)
bool
{
if
c
==
nil
||
c
.
Request
==
nil
{
return
false
}
// RFC6455 handshake uses:
// Connection: Upgrade
// Upgrade: websocket
upgrade
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
GetHeader
(
"Upgrade"
)))
if
upgrade
!=
"websocket"
{
return
false
}
connection
:=
strings
.
ToLower
(
c
.
GetHeader
(
"Connection"
))
return
strings
.
Contains
(
connection
,
"upgrade"
)
}
func
extractJWTFromWebSocketSubprotocol
(
c
*
gin
.
Context
)
string
{
if
c
==
nil
{
return
""
}
raw
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Sec-WebSocket-Protocol"
))
if
raw
==
""
{
return
""
}
// The header is a comma-separated list of tokens. We reserve the prefix "jwt."
// for carrying the admin JWT.
for
_
,
part
:=
range
strings
.
Split
(
raw
,
","
)
{
p
:=
strings
.
TrimSpace
(
part
)
if
strings
.
HasPrefix
(
p
,
"jwt."
)
{
token
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
p
,
"jwt."
))
if
token
!=
""
{
return
token
}
}
}
return
""
}
// validateAdminAPIKey 验证管理员 API Key
func
validateAdminAPIKey
(
c
*
gin
.
Context
,
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
b9b4db3d
package
middleware
import
(
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -71,6 +74,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if
len
(
apiKey
.
IPWhitelist
)
>
0
||
len
(
apiKey
.
IPBlacklist
)
>
0
{
clientIP
:=
ip
.
GetClientIP
(
c
)
allowed
,
_
:=
ip
.
CheckIPRestriction
(
clientIP
,
apiKey
.
IPWhitelist
,
apiKey
.
IPBlacklist
)
if
!
allowed
{
AbortWithError
(
c
,
403
,
"ACCESS_DENIED"
,
"Access denied"
)
return
}
}
// 检查关联的用户
if
apiKey
.
User
==
nil
{
AbortWithError
(
c
,
401
,
"USER_NOT_FOUND"
,
"User associated with API key not found"
)
...
...
@@ -91,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -149,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
...
...
@@ -173,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription
,
ok
:=
value
.
(
*
service
.
UserSubscription
)
return
subscription
,
ok
}
func
setGroupContext
(
c
*
gin
.
Context
,
group
*
service
.
Group
)
{
if
!
service
.
IsGroupContextValid
(
group
)
{
return
}
if
existing
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
service
.
IsGroupContextValid
(
existing
)
{
return
}
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
Group
,
group
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
backend/internal/server/middleware/api_key_auth_google.go
View file @
b9b4db3d
...
...
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
b9b4db3d
...
...
@@ -9,6 +9,7 @@ import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -26,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
func
(
f
fakeAPIKeyRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
APIKey
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeAPIKeyRepo
)
GetOwnerID
(
ctx
context
.
Context
,
id
int64
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
func
(
f
fakeAPIKeyRepo
)
Get
KeyAnd
OwnerID
(
ctx
context
.
Context
,
id
int64
)
(
string
,
int64
,
error
)
{
return
""
,
0
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeAPIKeyRepo
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
f
.
getByKey
==
nil
{
...
...
@@ -35,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK
}
return
f
.
getByKey
(
ctx
,
key
)
}
func
(
f
fakeAPIKeyRepo
)
GetByKeyForAuth
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
return
f
.
GetByKey
(
ctx
,
key
)
}
func
(
f
fakeAPIKeyRepo
)
Update
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -65,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64
func
(
f
fakeAPIKeyRepo
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeAPIKeyRepo
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeAPIKeyRepo
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
googleErrorResponse
struct
{
Error
struct
{
...
...
@@ -133,6 +143,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
99
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformGemini
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyService
:=
service
.
NewAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
},
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
},
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
:=
gin
.
New
()
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
b9b4db3d
...
...
@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID
:
42
,
Name
:
"sub"
,
Status
:
service
.
StatusActive
,
Hydrated
:
true
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
DailyLimitUSD
:
&
limit
,
}
...
...
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func
TestAPIKeyAuthSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAPIKeyAuthOverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
invalidGroup
:=
&
service
.
Group
{
ID
:
group
.
ID
,
Platform
:
group
.
Platform
,
Status
:
group
.
Status
,
}
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
||
!
groupFromCtx
.
Hydrated
||
groupFromCtx
==
invalidGroup
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
invalidGroup
))
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
...
@@ -131,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
GetOwnerID
(
ctx
context
.
Context
,
id
int64
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
func
(
r
*
stubApiKeyRepo
)
Get
KeyAnd
OwnerID
(
ctx
context
.
Context
,
id
int64
)
(
string
,
int64
,
error
)
{
return
""
,
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
...
...
@@ -142,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
GetByKeyForAuth
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
return
r
.
GetByKey
(
ctx
,
key
)
}
func
(
r
*
stubApiKeyRepo
)
Update
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -182,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubUserSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
...
...
backend/internal/server/middleware/client_request_id.go
0 → 100644
View file @
b9b4db3d
package
middleware
import
(
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
//
// This is used by the Ops monitoring module for end-to-end request correlation.
func
ClientRequestID
()
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
if
c
.
Request
==
nil
{
c
.
Next
()
return
}
if
v
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ClientRequestID
);
v
!=
nil
{
c
.
Next
()
return
}
id
:=
uuid
.
New
()
.
String
()
c
.
Request
=
c
.
Request
.
WithContext
(
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ClientRequestID
,
id
))
c
.
Next
()
}
}
backend/internal/server/middleware/security_headers.go
View file @
b9b4db3d
package
middleware
import
(
"crypto/rand"
"encoding/base64"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
const
(
// CSPNonceKey is the context key for storing the CSP nonce
CSPNonceKey
=
"csp_nonce"
// NonceTemplate is the placeholder in CSP policy for nonce
NonceTemplate
=
"__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain
=
"https://static.cloudflareinsights.com"
)
// GenerateNonce generates a cryptographically secure random nonce
func
GenerateNonce
()
string
{
b
:=
make
([]
byte
,
16
)
_
,
_
=
rand
.
Read
(
b
)
return
base64
.
StdEncoding
.
EncodeToString
(
b
)
}
// GetNonceFromContext retrieves the CSP nonce from gin context
func
GetNonceFromContext
(
c
*
gin
.
Context
)
string
{
if
nonce
,
exists
:=
c
.
Get
(
CSPNonceKey
);
exists
{
if
s
,
ok
:=
nonce
.
(
string
);
ok
{
return
s
}
}
return
""
}
// SecurityHeaders sets baseline security headers for all responses.
func
SecurityHeaders
(
cfg
config
.
CSPConfig
)
gin
.
HandlerFunc
{
policy
:=
strings
.
TrimSpace
(
cfg
.
Policy
)
...
...
@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy
=
config
.
DefaultCSPPolicy
}
// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
policy
=
enhanceCSPPolicy
(
policy
)
return
func
(
c
*
gin
.
Context
)
{
c
.
Header
(
"X-Content-Type-Options"
,
"nosniff"
)
c
.
Header
(
"X-Frame-Options"
,
"DENY"
)
c
.
Header
(
"Referrer-Policy"
,
"strict-origin-when-cross-origin"
)
if
cfg
.
Enabled
{
c
.
Header
(
"Content-Security-Policy"
,
policy
)
// Generate nonce for this request
nonce
:=
GenerateNonce
()
c
.
Set
(
CSPNonceKey
,
nonce
)
// Replace nonce placeholder in policy
finalPolicy
:=
strings
.
ReplaceAll
(
policy
,
NonceTemplate
,
"'nonce-"
+
nonce
+
"'"
)
c
.
Header
(
"Content-Security-Policy"
,
finalPolicy
)
}
c
.
Next
()
}
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func
enhanceCSPPolicy
(
policy
string
)
string
{
// Add nonce placeholder to script-src if not present
if
!
strings
.
Contains
(
policy
,
NonceTemplate
)
&&
!
strings
.
Contains
(
policy
,
"'nonce-"
)
{
policy
=
addToDirective
(
policy
,
"script-src"
,
NonceTemplate
)
}
// Add Cloudflare Insights domain to script-src if not present
if
!
strings
.
Contains
(
policy
,
CloudflareInsightsDomain
)
{
policy
=
addToDirective
(
policy
,
"script-src"
,
CloudflareInsightsDomain
)
}
return
policy
}
// addToDirective adds a value to a specific CSP directive.
// If the directive doesn't exist, it will be added after default-src.
func
addToDirective
(
policy
,
directive
,
value
string
)
string
{
// Find the directive in the policy
directivePrefix
:=
directive
+
" "
idx
:=
strings
.
Index
(
policy
,
directivePrefix
)
if
idx
==
-
1
{
// Directive not found, add it after default-src or at the beginning
defaultSrcIdx
:=
strings
.
Index
(
policy
,
"default-src "
)
if
defaultSrcIdx
!=
-
1
{
// Find the end of default-src directive (next semicolon)
endIdx
:=
strings
.
Index
(
policy
[
defaultSrcIdx
:
],
";"
)
if
endIdx
!=
-
1
{
insertPos
:=
defaultSrcIdx
+
endIdx
+
1
// Insert new directive after default-src
return
policy
[
:
insertPos
]
+
" "
+
directive
+
" 'self' "
+
value
+
";"
+
policy
[
insertPos
:
]
}
}
// Fallback: prepend the directive
return
directive
+
" 'self' "
+
value
+
"; "
+
policy
}
// Find the end of this directive (next semicolon or end of string)
endIdx
:=
strings
.
Index
(
policy
[
idx
:
],
";"
)
if
endIdx
==
-
1
{
// No semicolon found, directive goes to end of string
return
policy
+
" "
+
value
}
// Insert value before the semicolon
insertPos
:=
idx
+
endIdx
return
policy
[
:
insertPos
]
+
" "
+
value
+
policy
[
insertPos
:
]
}
backend/internal/server/middleware/security_headers_test.go
0 → 100644
View file @
b9b4db3d
package
middleware
import
(
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
init
()
{
gin
.
SetMode
(
gin
.
TestMode
)
}
func
TestGenerateNonce
(
t
*
testing
.
T
)
{
t
.
Run
(
"generates_valid_base64_string"
,
func
(
t
*
testing
.
T
)
{
nonce
:=
GenerateNonce
()
// Should be valid base64
decoded
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
nonce
)
require
.
NoError
(
t
,
err
)
// Should decode to 16 bytes
assert
.
Len
(
t
,
decoded
,
16
)
})
t
.
Run
(
"generates_unique_nonces"
,
func
(
t
*
testing
.
T
)
{
nonces
:=
make
(
map
[
string
]
bool
)
for
i
:=
0
;
i
<
100
;
i
++
{
nonce
:=
GenerateNonce
()
assert
.
False
(
t
,
nonces
[
nonce
],
"nonce should be unique"
)
nonces
[
nonce
]
=
true
}
})
t
.
Run
(
"nonce_has_expected_length"
,
func
(
t
*
testing
.
T
)
{
nonce
:=
GenerateNonce
()
// 16 bytes -> 24 chars in base64 (with padding)
assert
.
Len
(
t
,
nonce
,
24
)
})
}
func
TestGetNonceFromContext
(
t
*
testing
.
T
)
{
t
.
Run
(
"returns_nonce_when_present"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
expectedNonce
:=
"test-nonce-123"
c
.
Set
(
CSPNonceKey
,
expectedNonce
)
nonce
:=
GetNonceFromContext
(
c
)
assert
.
Equal
(
t
,
expectedNonce
,
nonce
)
})
t
.
Run
(
"returns_empty_string_when_not_present"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
nonce
:=
GetNonceFromContext
(
c
)
assert
.
Empty
(
t
,
nonce
)
})
t
.
Run
(
"returns_empty_for_wrong_type"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
// Set a non-string value
c
.
Set
(
CSPNonceKey
,
12345
)
// Should return empty string for wrong type (safe type assertion)
nonce
:=
GetNonceFromContext
(
c
)
assert
.
Empty
(
t
,
nonce
)
})
}
func
TestSecurityHeaders
(
t
*
testing
.
T
)
{
t
.
Run
(
"sets_basic_security_headers"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
false
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
assert
.
Equal
(
t
,
"nosniff"
,
w
.
Header
()
.
Get
(
"X-Content-Type-Options"
))
assert
.
Equal
(
t
,
"DENY"
,
w
.
Header
()
.
Get
(
"X-Frame-Options"
))
assert
.
Equal
(
t
,
"strict-origin-when-cross-origin"
,
w
.
Header
()
.
Get
(
"Referrer-Policy"
))
})
t
.
Run
(
"csp_disabled_no_csp_header"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
false
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Content-Security-Policy"
))
})
t
.
Run
(
"csp_enabled_sets_csp_header"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"default-src 'self'"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
csp
:=
w
.
Header
()
.
Get
(
"Content-Security-Policy"
)
assert
.
NotEmpty
(
t
,
csp
)
// Policy is auto-enhanced with nonce and Cloudflare Insights domain
assert
.
Contains
(
t
,
csp
,
"default-src 'self'"
)
assert
.
Contains
(
t
,
csp
,
"'nonce-"
)
assert
.
Contains
(
t
,
csp
,
CloudflareInsightsDomain
)
})
t
.
Run
(
"csp_enabled_with_nonce_placeholder"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"script-src 'self' __CSP_NONCE__"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
csp
:=
w
.
Header
()
.
Get
(
"Content-Security-Policy"
)
assert
.
NotEmpty
(
t
,
csp
)
assert
.
NotContains
(
t
,
csp
,
"__CSP_NONCE__"
,
"placeholder should be replaced"
)
assert
.
Contains
(
t
,
csp
,
"'nonce-"
,
"should contain nonce directive"
)
// Verify nonce is stored in context
nonce
:=
GetNonceFromContext
(
c
)
assert
.
NotEmpty
(
t
,
nonce
)
assert
.
Contains
(
t
,
csp
,
"'nonce-"
+
nonce
+
"'"
)
})
t
.
Run
(
"uses_default_policy_when_empty"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
""
,
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
csp
:=
w
.
Header
()
.
Get
(
"Content-Security-Policy"
)
assert
.
NotEmpty
(
t
,
csp
)
// Default policy should contain these elements
assert
.
Contains
(
t
,
csp
,
"default-src 'self'"
)
})
t
.
Run
(
"uses_default_policy_when_whitespace_only"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"
\t\n
"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
csp
:=
w
.
Header
()
.
Get
(
"Content-Security-Policy"
)
assert
.
NotEmpty
(
t
,
csp
)
assert
.
Contains
(
t
,
csp
,
"default-src 'self'"
)
})
t
.
Run
(
"multiple_nonce_placeholders_replaced"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"script-src __CSP_NONCE__; style-src __CSP_NONCE__"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
csp
:=
w
.
Header
()
.
Get
(
"Content-Security-Policy"
)
nonce
:=
GetNonceFromContext
(
c
)
// Count occurrences of the nonce
count
:=
strings
.
Count
(
csp
,
"'nonce-"
+
nonce
+
"'"
)
assert
.
Equal
(
t
,
2
,
count
,
"both placeholders should be replaced with same nonce"
)
})
t
.
Run
(
"calls_next_handler"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"default-src 'self'"
}
middleware
:=
SecurityHeaders
(
cfg
)
nextCalled
:=
false
router
:=
gin
.
New
()
router
.
Use
(
middleware
)
router
.
GET
(
"/test"
,
func
(
c
*
gin
.
Context
)
{
nextCalled
=
true
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/test"
,
nil
)
router
.
ServeHTTP
(
w
,
req
)
assert
.
True
(
t
,
nextCalled
,
"next handler should be called"
)
assert
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"nonce_unique_per_request"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"script-src __CSP_NONCE__"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
nonces
:=
make
(
map
[
string
]
bool
)
for
i
:=
0
;
i
<
10
;
i
++
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
nonce
:=
GetNonceFromContext
(
c
)
assert
.
False
(
t
,
nonces
[
nonce
],
"nonce should be unique per request"
)
nonces
[
nonce
]
=
true
}
})
}
func
TestCSPNonceKey
(
t
*
testing
.
T
)
{
t
.
Run
(
"constant_value"
,
func
(
t
*
testing
.
T
)
{
assert
.
Equal
(
t
,
"csp_nonce"
,
CSPNonceKey
)
})
}
func
TestNonceTemplate
(
t
*
testing
.
T
)
{
t
.
Run
(
"constant_value"
,
func
(
t
*
testing
.
T
)
{
assert
.
Equal
(
t
,
"__CSP_NONCE__"
,
NonceTemplate
)
})
}
func
TestEnhanceCSPPolicy
(
t
*
testing
.
T
)
{
t
.
Run
(
"adds_nonce_placeholder_if_missing"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'; script-src 'self'"
enhanced
:=
enhanceCSPPolicy
(
policy
)
assert
.
Contains
(
t
,
enhanced
,
NonceTemplate
)
assert
.
Contains
(
t
,
enhanced
,
CloudflareInsightsDomain
)
})
t
.
Run
(
"does_not_duplicate_nonce_placeholder"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'; script-src 'self' __CSP_NONCE__"
enhanced
:=
enhanceCSPPolicy
(
policy
)
// Should not duplicate
count
:=
strings
.
Count
(
enhanced
,
NonceTemplate
)
assert
.
Equal
(
t
,
1
,
count
)
})
t
.
Run
(
"does_not_duplicate_cloudflare_domain"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
enhanced
:=
enhanceCSPPolicy
(
policy
)
count
:=
strings
.
Count
(
enhanced
,
CloudflareInsightsDomain
)
assert
.
Equal
(
t
,
1
,
count
)
})
t
.
Run
(
"handles_policy_without_script_src"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'"
enhanced
:=
enhanceCSPPolicy
(
policy
)
assert
.
Contains
(
t
,
enhanced
,
"script-src"
)
assert
.
Contains
(
t
,
enhanced
,
NonceTemplate
)
assert
.
Contains
(
t
,
enhanced
,
CloudflareInsightsDomain
)
})
t
.
Run
(
"preserves_existing_nonce"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"script-src 'self' 'nonce-existing'"
enhanced
:=
enhanceCSPPolicy
(
policy
)
// Should not add placeholder if nonce already exists
assert
.
NotContains
(
t
,
enhanced
,
NonceTemplate
)
assert
.
Contains
(
t
,
enhanced
,
"'nonce-existing'"
)
})
}
func
TestAddToDirective
(
t
*
testing
.
T
)
{
t
.
Run
(
"adds_to_existing_directive"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"script-src 'self'; style-src 'self'"
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"script-src 'self' https://example.com"
)
})
t
.
Run
(
"creates_directive_if_not_exists"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'"
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"script-src"
)
assert
.
Contains
(
t
,
result
,
"https://example.com"
)
})
t
.
Run
(
"handles_directive_at_end_without_semicolon"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'; script-src 'self'"
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"https://example.com"
)
})
t
.
Run
(
"handles_empty_policy"
,
func
(
t
*
testing
.
T
)
{
policy
:=
""
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"script-src"
)
assert
.
Contains
(
t
,
result
,
"https://example.com"
)
})
}
// Benchmark tests
func
BenchmarkGenerateNonce
(
b
*
testing
.
B
)
{
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
GenerateNonce
()
}
}
func
BenchmarkSecurityHeadersMiddleware
(
b
*
testing
.
B
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"script-src 'self' __CSP_NONCE__"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
b
.
ResetTimer
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
}
}
Prev
1
…
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment