Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
07be258d
Unverified
Commit
07be258d
authored
Feb 24, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 24, 2026
Browse files
Merge pull request #603 from mt21625457/release
feat : 大幅度的性能优化 和 新增了很多功能
parents
dbdb2959
53d55bb9
Changes
268
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
268 of 268+
files are displayed.
Plain diff
Email patch
backend/internal/repository/ops_repo_openai_token_stats.go
0 → 100644
View file @
07be258d
package
repository
import
(
"context"
"database/sql"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
(
r
*
opsRepository
)
GetOpenAITokenStats
(
ctx
context
.
Context
,
filter
*
service
.
OpsOpenAITokenStatsFilter
)
(
*
service
.
OpsOpenAITokenStatsResponse
,
error
)
{
if
r
==
nil
||
r
.
db
==
nil
{
return
nil
,
fmt
.
Errorf
(
"nil ops repository"
)
}
if
filter
==
nil
{
return
nil
,
fmt
.
Errorf
(
"nil filter"
)
}
if
filter
.
StartTime
.
IsZero
()
||
filter
.
EndTime
.
IsZero
()
{
return
nil
,
fmt
.
Errorf
(
"start_time/end_time required"
)
}
// 允许 start_time == end_time(结果为空),与 service 层校验口径保持一致。
if
filter
.
StartTime
.
After
(
filter
.
EndTime
)
{
return
nil
,
fmt
.
Errorf
(
"start_time must be <= end_time"
)
}
dashboardFilter
:=
&
service
.
OpsDashboardFilter
{
StartTime
:
filter
.
StartTime
.
UTC
(),
EndTime
:
filter
.
EndTime
.
UTC
(),
Platform
:
strings
.
TrimSpace
(
strings
.
ToLower
(
filter
.
Platform
)),
GroupID
:
filter
.
GroupID
,
}
join
,
where
,
baseArgs
,
next
:=
buildUsageWhere
(
dashboardFilter
,
dashboardFilter
.
StartTime
,
dashboardFilter
.
EndTime
,
1
)
where
+=
" AND ul.model LIKE 'gpt%'"
baseCTE
:=
`
WITH stats AS (
SELECT
ul.model AS model,
COUNT(*)::bigint AS request_count,
ROUND(
AVG(
CASE
WHEN ul.duration_ms > 0 AND ul.output_tokens > 0
THEN ul.output_tokens * 1000.0 / ul.duration_ms
END
)::numeric,
2
)::float8 AS avg_tokens_per_sec,
ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms,
COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms,
COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token
FROM usage_logs ul
`
+
join
+
`
`
+
where
+
`
GROUP BY ul.model
)
`
countSQL
:=
baseCTE
+
`SELECT COUNT(*) FROM stats`
var
total
int64
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
countSQL
,
baseArgs
...
)
.
Scan
(
&
total
);
err
!=
nil
{
return
nil
,
err
}
querySQL
:=
baseCTE
+
`
SELECT
model,
request_count,
avg_tokens_per_sec,
avg_first_token_ms,
total_output_tokens,
avg_duration_ms,
requests_with_first_token
FROM stats
ORDER BY request_count DESC, model ASC`
args
:=
make
([]
any
,
0
,
len
(
baseArgs
)
+
2
)
args
=
append
(
args
,
baseArgs
...
)
if
filter
.
IsTopNMode
()
{
querySQL
+=
fmt
.
Sprintf
(
"
\n
LIMIT $%d"
,
next
)
args
=
append
(
args
,
filter
.
TopN
)
}
else
{
offset
:=
(
filter
.
Page
-
1
)
*
filter
.
PageSize
querySQL
+=
fmt
.
Sprintf
(
"
\n
LIMIT $%d OFFSET $%d"
,
next
,
next
+
1
)
args
=
append
(
args
,
filter
.
PageSize
,
offset
)
}
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
querySQL
,
args
...
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
items
:=
make
([]
*
service
.
OpsOpenAITokenStatsItem
,
0
,
32
)
for
rows
.
Next
()
{
item
:=
&
service
.
OpsOpenAITokenStatsItem
{}
var
avgTPS
sql
.
NullFloat64
var
avgFirstToken
sql
.
NullFloat64
if
err
:=
rows
.
Scan
(
&
item
.
Model
,
&
item
.
RequestCount
,
&
avgTPS
,
&
avgFirstToken
,
&
item
.
TotalOutputTokens
,
&
item
.
AvgDurationMs
,
&
item
.
RequestsWithFirstToken
,
);
err
!=
nil
{
return
nil
,
err
}
if
avgTPS
.
Valid
{
v
:=
avgTPS
.
Float64
item
.
AvgTokensPerSec
=
&
v
}
if
avgFirstToken
.
Valid
{
v
:=
avgFirstToken
.
Float64
item
.
AvgFirstTokenMs
=
&
v
}
items
=
append
(
items
,
item
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
resp
:=
&
service
.
OpsOpenAITokenStatsResponse
{
TimeRange
:
strings
.
TrimSpace
(
filter
.
TimeRange
),
StartTime
:
dashboardFilter
.
StartTime
,
EndTime
:
dashboardFilter
.
EndTime
,
Platform
:
dashboardFilter
.
Platform
,
GroupID
:
dashboardFilter
.
GroupID
,
Items
:
items
,
Total
:
total
,
}
if
filter
.
IsTopNMode
()
{
topN
:=
filter
.
TopN
resp
.
TopN
=
&
topN
}
else
{
resp
.
Page
=
filter
.
Page
resp
.
PageSize
=
filter
.
PageSize
}
return
resp
,
nil
}
backend/internal/repository/ops_repo_openai_token_stats_test.go
0 → 100644
View file @
07be258d
package
repository
import
(
"context"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestOpsRepositoryGetOpenAITokenStats_PaginationMode
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
opsRepository
{
db
:
db
}
start
:=
time
.
Date
(
2026
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
groupID
:=
int64
(
9
)
filter
:=
&
service
.
OpsOpenAITokenStatsFilter
{
TimeRange
:
"1d"
,
StartTime
:
start
,
EndTime
:
end
,
Platform
:
" OpenAI "
,
GroupID
:
&
groupID
,
Page
:
2
,
PageSize
:
10
,
}
mock
.
ExpectQuery
(
`SELECT COUNT\(\*\) FROM stats`
)
.
WithArgs
(
start
,
end
,
groupID
,
"openai"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
3
)))
rows
:=
sqlmock
.
NewRows
([]
string
{
"model"
,
"request_count"
,
"avg_tokens_per_sec"
,
"avg_first_token_ms"
,
"total_output_tokens"
,
"avg_duration_ms"
,
"requests_with_first_token"
,
})
.
AddRow
(
"gpt-4o-mini"
,
int64
(
20
),
21.56
,
120.34
,
int64
(
3000
),
int64
(
850
),
int64
(
18
))
.
AddRow
(
"gpt-4.1"
,
int64
(
20
),
10.2
,
240.0
,
int64
(
2500
),
int64
(
900
),
int64
(
20
))
mock
.
ExpectQuery
(
`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`
)
.
WithArgs
(
start
,
end
,
groupID
,
"openai"
,
10
,
10
)
.
WillReturnRows
(
rows
)
resp
,
err
:=
repo
.
GetOpenAITokenStats
(
context
.
Background
(),
filter
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
resp
)
require
.
Equal
(
t
,
int64
(
3
),
resp
.
Total
)
require
.
Equal
(
t
,
2
,
resp
.
Page
)
require
.
Equal
(
t
,
10
,
resp
.
PageSize
)
require
.
Nil
(
t
,
resp
.
TopN
)
require
.
Equal
(
t
,
"openai"
,
resp
.
Platform
)
require
.
NotNil
(
t
,
resp
.
GroupID
)
require
.
Equal
(
t
,
groupID
,
*
resp
.
GroupID
)
require
.
Len
(
t
,
resp
.
Items
,
2
)
require
.
Equal
(
t
,
"gpt-4o-mini"
,
resp
.
Items
[
0
]
.
Model
)
require
.
NotNil
(
t
,
resp
.
Items
[
0
]
.
AvgTokensPerSec
)
require
.
InDelta
(
t
,
21.56
,
*
resp
.
Items
[
0
]
.
AvgTokensPerSec
,
0.0001
)
require
.
NotNil
(
t
,
resp
.
Items
[
0
]
.
AvgFirstTokenMs
)
require
.
InDelta
(
t
,
120.34
,
*
resp
.
Items
[
0
]
.
AvgFirstTokenMs
,
0.0001
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestOpsRepositoryGetOpenAITokenStats_TopNMode
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
opsRepository
{
db
:
db
}
start
:=
time
.
Date
(
2026
,
1
,
1
,
10
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
time
.
Hour
)
filter
:=
&
service
.
OpsOpenAITokenStatsFilter
{
TimeRange
:
"1h"
,
StartTime
:
start
,
EndTime
:
end
,
TopN
:
5
,
}
mock
.
ExpectQuery
(
`SELECT COUNT\(\*\) FROM stats`
)
.
WithArgs
(
start
,
end
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
1
)))
rows
:=
sqlmock
.
NewRows
([]
string
{
"model"
,
"request_count"
,
"avg_tokens_per_sec"
,
"avg_first_token_ms"
,
"total_output_tokens"
,
"avg_duration_ms"
,
"requests_with_first_token"
,
})
.
AddRow
(
"gpt-4o"
,
int64
(
5
),
nil
,
nil
,
int64
(
0
),
int64
(
0
),
int64
(
0
))
mock
.
ExpectQuery
(
`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`
)
.
WithArgs
(
start
,
end
,
5
)
.
WillReturnRows
(
rows
)
resp
,
err
:=
repo
.
GetOpenAITokenStats
(
context
.
Background
(),
filter
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
resp
)
require
.
NotNil
(
t
,
resp
.
TopN
)
require
.
Equal
(
t
,
5
,
*
resp
.
TopN
)
require
.
Equal
(
t
,
0
,
resp
.
Page
)
require
.
Equal
(
t
,
0
,
resp
.
PageSize
)
require
.
Len
(
t
,
resp
.
Items
,
1
)
require
.
Nil
(
t
,
resp
.
Items
[
0
]
.
AvgTokensPerSec
)
require
.
Nil
(
t
,
resp
.
Items
[
0
]
.
AvgFirstTokenMs
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestOpsRepositoryGetOpenAITokenStats_EmptyResult
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
opsRepository
{
db
:
db
}
start
:=
time
.
Date
(
2026
,
1
,
2
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
30
*
time
.
Minute
)
filter
:=
&
service
.
OpsOpenAITokenStatsFilter
{
TimeRange
:
"30m"
,
StartTime
:
start
,
EndTime
:
end
,
Page
:
1
,
PageSize
:
20
,
}
mock
.
ExpectQuery
(
`SELECT COUNT\(\*\) FROM stats`
)
.
WithArgs
(
start
,
end
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"count"
})
.
AddRow
(
int64
(
0
)))
mock
.
ExpectQuery
(
`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`
)
.
WithArgs
(
start
,
end
,
20
,
0
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"model"
,
"request_count"
,
"avg_tokens_per_sec"
,
"avg_first_token_ms"
,
"total_output_tokens"
,
"avg_duration_ms"
,
"requests_with_first_token"
,
}))
resp
,
err
:=
repo
.
GetOpenAITokenStats
(
context
.
Background
(),
filter
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
resp
)
require
.
Equal
(
t
,
int64
(
0
),
resp
.
Total
)
require
.
Len
(
t
,
resp
.
Items
,
0
)
require
.
Equal
(
t
,
1
,
resp
.
Page
)
require
.
Equal
(
t
,
20
,
resp
.
PageSize
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
backend/internal/repository/ops_repo_system_logs_test.go
0 → 100644
View file @
07be258d
package
repository
import
(
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2026
,
2
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
time
.
Date
(
2026
,
2
,
2
,
0
,
0
,
0
,
0
,
time
.
UTC
)
userID
:=
int64
(
12
)
accountID
:=
int64
(
34
)
filter
:=
&
service
.
OpsSystemLogFilter
{
StartTime
:
&
start
,
EndTime
:
&
end
,
Level
:
"warn"
,
Component
:
"http.access"
,
RequestID
:
"req-1"
,
ClientRequestID
:
"creq-1"
,
UserID
:
&
userID
,
AccountID
:
&
accountID
,
Platform
:
"openai"
,
Model
:
"gpt-5"
,
Query
:
"timeout"
,
}
where
,
args
,
hasConstraint
:=
buildOpsSystemLogsWhere
(
filter
)
if
!
hasConstraint
{
t
.
Fatalf
(
"expected hasConstraint=true"
)
}
if
where
==
""
{
t
.
Fatalf
(
"where should not be empty"
)
}
if
len
(
args
)
!=
11
{
t
.
Fatalf
(
"args len = %d, want 11"
,
len
(
args
))
}
if
!
contains
(
where
,
"COALESCE(l.client_request_id,'') = $"
)
{
t
.
Fatalf
(
"where should include client_request_id condition: %s"
,
where
)
}
if
!
contains
(
where
,
"l.user_id = $"
)
{
t
.
Fatalf
(
"where should include user_id condition: %s"
,
where
)
}
}
func
TestBuildOpsSystemLogsCleanupWhere_RequireConstraint
(
t
*
testing
.
T
)
{
where
,
args
,
hasConstraint
:=
buildOpsSystemLogsCleanupWhere
(
&
service
.
OpsSystemLogCleanupFilter
{})
if
hasConstraint
{
t
.
Fatalf
(
"expected hasConstraint=false"
)
}
if
where
==
""
{
t
.
Fatalf
(
"where should not be empty"
)
}
if
len
(
args
)
!=
0
{
t
.
Fatalf
(
"args len = %d, want 0"
,
len
(
args
))
}
}
func
TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID
(
t
*
testing
.
T
)
{
userID
:=
int64
(
9
)
filter
:=
&
service
.
OpsSystemLogCleanupFilter
{
ClientRequestID
:
"creq-9"
,
UserID
:
&
userID
,
}
where
,
args
,
hasConstraint
:=
buildOpsSystemLogsCleanupWhere
(
filter
)
if
!
hasConstraint
{
t
.
Fatalf
(
"expected hasConstraint=true"
)
}
if
len
(
args
)
!=
2
{
t
.
Fatalf
(
"args len = %d, want 2"
,
len
(
args
))
}
if
!
contains
(
where
,
"COALESCE(l.client_request_id,'') = $"
)
{
t
.
Fatalf
(
"where should include client_request_id condition: %s"
,
where
)
}
if
!
contains
(
where
,
"l.user_id = $"
)
{
t
.
Fatalf
(
"where should include user_id condition: %s"
,
where
)
}
}
func
contains
(
s
string
,
sub
string
)
bool
{
return
strings
.
Contains
(
s
,
sub
)
}
backend/internal/repository/promo_code_repo.go
View file @
07be258d
...
...
@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
q
=
q
.
Where
(
promocode
.
CodeContainsFold
(
search
))
}
total
,
err
:=
q
.
Count
(
ctx
)
total
,
err
:=
q
.
Clone
()
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
q
:=
r
.
client
.
PromoCodeUsage
.
Query
()
.
Where
(
promocodeusage
.
PromoCodeIDEQ
(
promoCodeID
))
total
,
err
:=
q
.
Count
(
ctx
)
total
,
err
:=
q
.
Clone
()
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
backend/internal/repository/proxy_probe_service.go
View file @
07be258d
...
...
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure
:=
false
allowPrivate
:=
false
validateResolvedIP
:=
true
maxResponseBytes
:=
defaultProxyProbeResponseMaxBytes
if
cfg
!=
nil
{
insecure
=
cfg
.
Security
.
ProxyProbe
.
InsecureSkipVerify
allowPrivate
=
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
validateResolvedIP
=
cfg
.
Security
.
URLAllowlist
.
Enabled
if
cfg
.
Gateway
.
ProxyProbeResponseReadMaxBytes
>
0
{
maxResponseBytes
=
cfg
.
Gateway
.
ProxyProbeResponseReadMaxBytes
}
}
if
insecure
{
log
.
Printf
(
"[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure."
)
...
...
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecureSkipVerify
:
insecure
,
allowPrivateHosts
:
allowPrivate
,
validateResolvedIP
:
validateResolvedIP
,
maxResponseBytes
:
maxResponseBytes
,
}
}
const
(
defaultProxyProbeTimeout
=
30
*
time
.
Second
defaultProxyProbeTimeout
=
30
*
time
.
Second
defaultProxyProbeResponseMaxBytes
=
int64
(
1024
*
1024
)
)
// probeURLs 按优先级排列的探测 URL 列表
...
...
@@ -52,6 +58,7 @@ type proxyProbeService struct {
insecureSkipVerify
bool
allowPrivateHosts
bool
validateResolvedIP
bool
maxResponseBytes
int64
}
func
(
s
*
proxyProbeService
)
ProbeProxy
(
ctx
context
.
Context
,
proxyURL
string
)
(
*
service
.
ProxyExitInfo
,
int64
,
error
)
{
...
...
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"request failed with status: %d"
,
resp
.
StatusCode
)
}
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
maxResponseBytes
:=
s
.
maxResponseBytes
if
maxResponseBytes
<=
0
{
maxResponseBytes
=
defaultProxyProbeResponseMaxBytes
}
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
maxResponseBytes
+
1
))
if
err
!=
nil
{
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"failed to read response: %w"
,
err
)
}
if
int64
(
len
(
body
))
>
maxResponseBytes
{
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"proxy probe response exceeds limit: %d"
,
maxResponseBytes
)
}
switch
parser
{
case
"ip-api"
:
...
...
backend/internal/repository/security_secret_bootstrap.go
0 → 100644
View file @
07be258d
package
repository
import
(
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const
(
securitySecretKeyJWT
=
"jwt_secret"
securitySecretReadRetryMax
=
5
securitySecretReadRetryWait
=
10
*
time
.
Millisecond
)
var
readRandomBytes
=
rand
.
Read
func
ensureBootstrapSecrets
(
ctx
context
.
Context
,
client
*
ent
.
Client
,
cfg
*
config
.
Config
)
error
{
if
client
==
nil
{
return
fmt
.
Errorf
(
"nil ent client"
)
}
if
cfg
==
nil
{
return
fmt
.
Errorf
(
"nil config"
)
}
cfg
.
JWT
.
Secret
=
strings
.
TrimSpace
(
cfg
.
JWT
.
Secret
)
if
cfg
.
JWT
.
Secret
!=
""
{
storedSecret
,
err
:=
createSecuritySecretIfAbsent
(
ctx
,
client
,
securitySecretKeyJWT
,
cfg
.
JWT
.
Secret
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"persist jwt secret: %w"
,
err
)
}
if
storedSecret
!=
cfg
.
JWT
.
Secret
{
log
.
Println
(
"Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency."
)
}
cfg
.
JWT
.
Secret
=
storedSecret
return
nil
}
secret
,
created
,
err
:=
getOrCreateGeneratedSecuritySecret
(
ctx
,
client
,
securitySecretKeyJWT
,
32
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"ensure jwt secret: %w"
,
err
)
}
cfg
.
JWT
.
Secret
=
secret
if
created
{
log
.
Println
(
"Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production."
)
}
return
nil
}
func
getOrCreateGeneratedSecuritySecret
(
ctx
context
.
Context
,
client
*
ent
.
Client
,
key
string
,
byteLength
int
)
(
string
,
bool
,
error
)
{
existing
,
err
:=
client
.
SecuritySecret
.
Query
()
.
Where
(
securitysecret
.
KeyEQ
(
key
))
.
Only
(
ctx
)
if
err
==
nil
{
value
:=
strings
.
TrimSpace
(
existing
.
Value
)
if
len
([]
byte
(
value
))
<
32
{
return
""
,
false
,
fmt
.
Errorf
(
"stored secret %q must be at least 32 bytes"
,
key
)
}
return
value
,
false
,
nil
}
if
!
ent
.
IsNotFound
(
err
)
{
return
""
,
false
,
err
}
generated
,
err
:=
generateHexSecret
(
byteLength
)
if
err
!=
nil
{
return
""
,
false
,
err
}
if
err
:=
client
.
SecuritySecret
.
Create
()
.
SetKey
(
key
)
.
SetValue
(
generated
)
.
OnConflictColumns
(
securitysecret
.
FieldKey
)
.
DoNothing
()
.
Exec
(
ctx
);
err
!=
nil
{
if
!
isSQLNoRowsError
(
err
)
{
return
""
,
false
,
err
}
}
stored
,
err
:=
querySecuritySecretWithRetry
(
ctx
,
client
,
key
)
if
err
!=
nil
{
return
""
,
false
,
err
}
value
:=
strings
.
TrimSpace
(
stored
.
Value
)
if
len
([]
byte
(
value
))
<
32
{
return
""
,
false
,
fmt
.
Errorf
(
"stored secret %q must be at least 32 bytes"
,
key
)
}
return
value
,
value
==
generated
,
nil
}
func
createSecuritySecretIfAbsent
(
ctx
context
.
Context
,
client
*
ent
.
Client
,
key
,
value
string
)
(
string
,
error
)
{
value
=
strings
.
TrimSpace
(
value
)
if
len
([]
byte
(
value
))
<
32
{
return
""
,
fmt
.
Errorf
(
"secret %q must be at least 32 bytes"
,
key
)
}
if
err
:=
client
.
SecuritySecret
.
Create
()
.
SetKey
(
key
)
.
SetValue
(
value
)
.
OnConflictColumns
(
securitysecret
.
FieldKey
)
.
DoNothing
()
.
Exec
(
ctx
);
err
!=
nil
{
if
!
isSQLNoRowsError
(
err
)
{
return
""
,
err
}
}
stored
,
err
:=
querySecuritySecretWithRetry
(
ctx
,
client
,
key
)
if
err
!=
nil
{
return
""
,
err
}
storedValue
:=
strings
.
TrimSpace
(
stored
.
Value
)
if
len
([]
byte
(
storedValue
))
<
32
{
return
""
,
fmt
.
Errorf
(
"stored secret %q must be at least 32 bytes"
,
key
)
}
return
storedValue
,
nil
}
func
querySecuritySecretWithRetry
(
ctx
context
.
Context
,
client
*
ent
.
Client
,
key
string
)
(
*
ent
.
SecuritySecret
,
error
)
{
var
lastErr
error
for
attempt
:=
0
;
attempt
<=
securitySecretReadRetryMax
;
attempt
++
{
stored
,
err
:=
client
.
SecuritySecret
.
Query
()
.
Where
(
securitysecret
.
KeyEQ
(
key
))
.
Only
(
ctx
)
if
err
==
nil
{
return
stored
,
nil
}
if
!
isSecretNotFoundError
(
err
)
{
return
nil
,
err
}
lastErr
=
err
if
attempt
==
securitySecretReadRetryMax
{
break
}
timer
:=
time
.
NewTimer
(
securitySecretReadRetryWait
)
select
{
case
<-
ctx
.
Done
()
:
timer
.
Stop
()
return
nil
,
ctx
.
Err
()
case
<-
timer
.
C
:
}
}
return
nil
,
lastErr
}
func
isSecretNotFoundError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
return
ent
.
IsNotFound
(
err
)
||
isSQLNoRowsError
(
err
)
}
func
isSQLNoRowsError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
return
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
||
strings
.
Contains
(
err
.
Error
(),
"no rows in result set"
)
}
func
generateHexSecret
(
byteLength
int
)
(
string
,
error
)
{
if
byteLength
<=
0
{
byteLength
=
32
}
buf
:=
make
([]
byte
,
byteLength
)
if
_
,
err
:=
readRandomBytes
(
buf
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate random secret: %w"
,
err
)
}
return
hex
.
EncodeToString
(
buf
),
nil
}
backend/internal/repository/security_secret_bootstrap_test.go
0 → 100644
View file @
07be258d
package
repository
import
(
"context"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"strings"
"sync"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql
"entgo.io/ent/dialect/sql"
_
"modernc.org/sqlite"
)
func
newSecuritySecretTestClient
(
t
*
testing
.
T
)
*
dbent
.
Client
{
t
.
Helper
()
name
:=
strings
.
ReplaceAll
(
t
.
Name
(),
"/"
,
"_"
)
dsn
:=
fmt
.
Sprintf
(
"file:%s?mode=memory&cache=shared&_fk=1"
,
name
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
dsn
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
return
client
}
func
TestEnsureBootstrapSecretsNilInputs
(
t
*
testing
.
T
)
{
err
:=
ensureBootstrapSecrets
(
context
.
Background
(),
nil
,
&
config
.
Config
{})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"nil ent client"
)
client
:=
newSecuritySecretTestClient
(
t
)
err
=
ensureBootstrapSecrets
(
context
.
Background
(),
client
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"nil config"
)
}
func
TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
cfg
:=
&
config
.
Config
{}
err
:=
ensureBootstrapSecrets
(
context
.
Background
(),
client
,
cfg
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
cfg
.
JWT
.
Secret
)
require
.
GreaterOrEqual
(
t
,
len
([]
byte
(
cfg
.
JWT
.
Secret
)),
32
)
stored
,
err
:=
client
.
SecuritySecret
.
Query
()
.
Where
(
securitysecret
.
KeyEQ
(
securitySecretKeyJWT
))
.
Only
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
cfg
.
JWT
.
Secret
,
stored
.
Value
)
}
func
TestEnsureBootstrapSecretsLoadExistingJWTSecret
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
_
,
err
:=
client
.
SecuritySecret
.
Create
()
.
SetKey
(
securitySecretKeyJWT
)
.
SetValue
(
"existing-jwt-secret-32bytes-long!!!!"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
cfg
:=
&
config
.
Config
{}
err
=
ensureBootstrapSecrets
(
context
.
Background
(),
client
,
cfg
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"existing-jwt-secret-32bytes-long!!!!"
,
cfg
.
JWT
.
Secret
)
}
func
TestEnsureBootstrapSecretsRejectInvalidStoredSecret
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
_
,
err
:=
client
.
SecuritySecret
.
Create
()
.
SetKey
(
securitySecretKeyJWT
)
.
SetValue
(
"too-short"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
cfg
:=
&
config
.
Config
{}
err
=
ensureBootstrapSecrets
(
context
.
Background
(),
client
,
cfg
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"at least 32 bytes"
)
}
func
TestEnsureBootstrapSecretsPersistConfiguredJWTSecret
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"configured-jwt-secret-32bytes-long!!"
},
}
err
:=
ensureBootstrapSecrets
(
context
.
Background
(),
client
,
cfg
)
require
.
NoError
(
t
,
err
)
stored
,
err
:=
client
.
SecuritySecret
.
Query
()
.
Where
(
securitysecret
.
KeyEQ
(
securitySecretKeyJWT
))
.
Only
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"configured-jwt-secret-32bytes-long!!"
,
stored
.
Value
)
}
func
TestEnsureBootstrapSecretsConfiguredSecretTooShort
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"short"
}}
err
:=
ensureBootstrapSecrets
(
context
.
Background
(),
client
,
cfg
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"at least 32 bytes"
)
}
func
TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
_
,
err
:=
client
.
SecuritySecret
.
Create
()
.
SetKey
(
securitySecretKeyJWT
)
.
SetValue
(
"existing-jwt-secret-32bytes-long!!!!"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"another-configured-jwt-secret-32!!!!"
}}
err
=
ensureBootstrapSecrets
(
context
.
Background
(),
client
,
cfg
)
require
.
NoError
(
t
,
err
)
stored
,
err
:=
client
.
SecuritySecret
.
Query
()
.
Where
(
securitysecret
.
KeyEQ
(
securitySecretKeyJWT
))
.
Only
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"existing-jwt-secret-32bytes-long!!!!"
,
stored
.
Value
)
require
.
Equal
(
t
,
"existing-jwt-secret-32bytes-long!!!!"
,
cfg
.
JWT
.
Secret
)
}
func
TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
_
,
err
:=
client
.
SecuritySecret
.
Create
()
.
SetKey
(
"trimmed_key"
)
.
SetValue
(
" existing-trimmed-secret-32bytes-long!! "
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
value
,
created
,
err
:=
getOrCreateGeneratedSecuritySecret
(
context
.
Background
(),
client
,
"trimmed_key"
,
32
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
created
)
require
.
Equal
(
t
,
"existing-trimmed-secret-32bytes-long!!"
,
value
)
}
func
TestGetOrCreateGeneratedSecuritySecretQueryError
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
require
.
NoError
(
t
,
client
.
Close
())
_
,
_
,
err
:=
getOrCreateGeneratedSecuritySecret
(
context
.
Background
(),
client
,
"closed_client_key"
,
32
)
require
.
Error
(
t
,
err
)
}
func
TestGetOrCreateGeneratedSecuritySecretCreateValidationError
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
tooLongKey
:=
strings
.
Repeat
(
"k"
,
101
)
_
,
_
,
err
:=
getOrCreateGeneratedSecuritySecret
(
context
.
Background
(),
client
,
tooLongKey
,
32
)
require
.
Error
(
t
,
err
)
}
func
TestGetOrCreateGeneratedSecuritySecretConcurrentCreation
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
const
goroutines
=
8
key
:=
"concurrent_bootstrap_key"
values
:=
make
([]
string
,
goroutines
)
createdFlags
:=
make
([]
bool
,
goroutines
)
errs
:=
make
([]
error
,
goroutines
)
var
wg
sync
.
WaitGroup
for
i
:=
0
;
i
<
goroutines
;
i
++
{
wg
.
Add
(
1
)
go
func
(
idx
int
)
{
defer
wg
.
Done
()
values
[
idx
],
createdFlags
[
idx
],
errs
[
idx
]
=
getOrCreateGeneratedSecuritySecret
(
context
.
Background
(),
client
,
key
,
32
)
}(
i
)
}
wg
.
Wait
()
for
i
:=
range
errs
{
require
.
NoError
(
t
,
errs
[
i
])
require
.
NotEmpty
(
t
,
values
[
i
])
}
for
i
:=
1
;
i
<
len
(
values
);
i
++
{
require
.
Equal
(
t
,
values
[
0
],
values
[
i
])
}
createdCount
:=
0
for
_
,
created
:=
range
createdFlags
{
if
created
{
createdCount
++
}
}
require
.
GreaterOrEqual
(
t
,
createdCount
,
1
)
require
.
LessOrEqual
(
t
,
createdCount
,
1
)
count
,
err
:=
client
.
SecuritySecret
.
Query
()
.
Where
(
securitysecret
.
KeyEQ
(
key
))
.
Count
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
}
func
TestGetOrCreateGeneratedSecuritySecretGenerateError
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
originalRead
:=
readRandomBytes
readRandomBytes
=
func
([]
byte
)
(
int
,
error
)
{
return
0
,
errors
.
New
(
"boom"
)
}
t
.
Cleanup
(
func
()
{
readRandomBytes
=
originalRead
})
_
,
_
,
err
:=
getOrCreateGeneratedSecuritySecret
(
context
.
Background
(),
client
,
"gen_error_key"
,
32
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"boom"
)
}
func
TestCreateSecuritySecretIfAbsent
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
_
,
err
:=
createSecuritySecretIfAbsent
(
context
.
Background
(),
client
,
"abc"
,
"short"
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"at least 32 bytes"
)
stored
,
err
:=
createSecuritySecretIfAbsent
(
context
.
Background
(),
client
,
"abc"
,
"valid-jwt-secret-value-32bytes-long"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"valid-jwt-secret-value-32bytes-long"
,
stored
)
stored
,
err
=
createSecuritySecretIfAbsent
(
context
.
Background
(),
client
,
"abc"
,
"another-valid-secret-value-32bytes"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"valid-jwt-secret-value-32bytes-long"
,
stored
)
count
,
err
:=
client
.
SecuritySecret
.
Query
()
.
Where
(
securitysecret
.
KeyEQ
(
"abc"
))
.
Count
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
}
func
TestCreateSecuritySecretIfAbsentValidationError
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
_
,
err
:=
createSecuritySecretIfAbsent
(
context
.
Background
(),
client
,
strings
.
Repeat
(
"k"
,
101
),
"valid-jwt-secret-value-32bytes-long"
,
)
require
.
Error
(
t
,
err
)
}
func
TestCreateSecuritySecretIfAbsentExecError
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
require
.
NoError
(
t
,
client
.
Close
())
_
,
err
:=
createSecuritySecretIfAbsent
(
context
.
Background
(),
client
,
"closed-client-key"
,
"valid-jwt-secret-value-32bytes-long"
)
require
.
Error
(
t
,
err
)
}
func
TestQuerySecuritySecretWithRetrySuccess
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
created
,
err
:=
client
.
SecuritySecret
.
Create
()
.
SetKey
(
"retry_success_key"
)
.
SetValue
(
"retry-success-jwt-secret-value-32!!"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
got
,
err
:=
querySecuritySecretWithRetry
(
context
.
Background
(),
client
,
"retry_success_key"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
created
.
ID
,
got
.
ID
)
require
.
Equal
(
t
,
"retry-success-jwt-secret-value-32!!"
,
got
.
Value
)
}
func
TestQuerySecuritySecretWithRetryExhausted
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
_
,
err
:=
querySecuritySecretWithRetry
(
context
.
Background
(),
client
,
"retry_missing_key"
)
require
.
Error
(
t
,
err
)
require
.
True
(
t
,
isSecretNotFoundError
(
err
))
}
func
TestQuerySecuritySecretWithRetryContextCanceled
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
securitySecretReadRetryWait
/
2
)
defer
cancel
()
_
,
err
:=
querySecuritySecretWithRetry
(
ctx
,
client
,
"retry_ctx_cancel_key"
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
context
.
DeadlineExceeded
)
}
func
TestQuerySecuritySecretWithRetryNonNotFoundError
(
t
*
testing
.
T
)
{
client
:=
newSecuritySecretTestClient
(
t
)
require
.
NoError
(
t
,
client
.
Close
())
_
,
err
:=
querySecuritySecretWithRetry
(
context
.
Background
(),
client
,
"retry_closed_client_key"
)
require
.
Error
(
t
,
err
)
require
.
False
(
t
,
isSecretNotFoundError
(
err
))
}
func
TestSecretNotFoundHelpers
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isSecretNotFoundError
(
nil
))
require
.
False
(
t
,
isSQLNoRowsError
(
nil
))
require
.
True
(
t
,
isSQLNoRowsError
(
sql
.
ErrNoRows
))
require
.
True
(
t
,
isSQLNoRowsError
(
fmt
.
Errorf
(
"wrapped: %w"
,
sql
.
ErrNoRows
)))
require
.
True
(
t
,
isSQLNoRowsError
(
errors
.
New
(
"sql: no rows in result set"
)))
require
.
True
(
t
,
isSecretNotFoundError
(
sql
.
ErrNoRows
))
require
.
True
(
t
,
isSecretNotFoundError
(
errors
.
New
(
"sql: no rows in result set"
)))
require
.
False
(
t
,
isSecretNotFoundError
(
errors
.
New
(
"some other error"
)))
}
func
TestGenerateHexSecretReadError
(
t
*
testing
.
T
)
{
originalRead
:=
readRandomBytes
readRandomBytes
=
func
([]
byte
)
(
int
,
error
)
{
return
0
,
errors
.
New
(
"read random failed"
)
}
t
.
Cleanup
(
func
()
{
readRandomBytes
=
originalRead
})
_
,
err
:=
generateHexSecret
(
32
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"read random failed"
)
}
func
TestGenerateHexSecretLengths
(
t
*
testing
.
T
)
{
v1
,
err
:=
generateHexSecret
(
0
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
v1
,
64
)
_
,
err
=
hex
.
DecodeString
(
v1
)
require
.
NoError
(
t
,
err
)
v2
,
err
:=
generateHexSecret
(
16
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
v2
,
32
)
_
,
err
=
hex
.
DecodeString
(
v2
)
require
.
NoError
(
t
,
err
)
require
.
NotEqual
(
t
,
v1
,
v2
)
}
backend/internal/repository/sora_account_repo.go
0 → 100644
View file @
07be258d
package
repository
import
(
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
type
soraAccountRepository
struct
{
sql
*
sql
.
DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func
NewSoraAccountRepository
(
sqlDB
*
sql
.
DB
)
service
.
SoraAccountRepository
{
return
&
soraAccountRepository
{
sql
:
sqlDB
}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func
(
r
*
soraAccountRepository
)
Upsert
(
ctx
context
.
Context
,
accountID
int64
,
updates
map
[
string
]
any
)
error
{
accessToken
,
accessOK
:=
updates
[
"access_token"
]
.
(
string
)
refreshToken
,
refreshOK
:=
updates
[
"refresh_token"
]
.
(
string
)
sessionToken
,
sessionOK
:=
updates
[
"session_token"
]
.
(
string
)
if
!
accessOK
||
accessToken
==
""
||
!
refreshOK
||
refreshToken
==
""
{
if
!
sessionOK
{
return
errors
.
New
(
"缺少 access_token/refresh_token,且未提供可更新字段"
)
}
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`
,
accountID
,
sessionToken
)
if
err
!=
nil
{
return
err
}
rows
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
err
}
if
rows
==
0
{
return
errors
.
New
(
"sora_accounts 记录不存在,无法仅更新 session_token"
)
}
return
nil
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`
,
accountID
,
accessToken
,
refreshToken
,
sessionToken
)
return
err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func
(
r
*
soraAccountRepository
)
GetByAccountID
(
ctx
context
.
Context
,
accountID
int64
)
(
*
service
.
SoraAccount
,
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`
,
accountID
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
if
!
rows
.
Next
()
{
return
nil
,
nil
// 记录不存在
}
var
sa
service
.
SoraAccount
if
err
:=
rows
.
Scan
(
&
sa
.
AccountID
,
&
sa
.
AccessToken
,
&
sa
.
RefreshToken
,
&
sa
.
SessionToken
);
err
!=
nil
{
return
nil
,
err
}
return
&
sa
,
nil
}
// Delete 删除 Sora 账号扩展信息
func
(
r
*
soraAccountRepository
)
Delete
(
ctx
context
.
Context
,
accountID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM sora_accounts WHERE account_id = $1
`
,
accountID
)
return
err
}
backend/internal/repository/usage_log_repo.go
View file @
07be258d
...
...
@@ -22,7 +22,23 @@ import (
"github.com/lib/pq"
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, cache_ttl_overridden, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var
dateFormatWhitelist
=
map
[
string
]
string
{
"hour"
:
"YYYY-MM-DD HH24:00"
,
"day"
:
"YYYY-MM-DD"
,
"week"
:
"IYYY-IW"
,
"month"
:
"YYYY-MM"
,
}
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
func
safeDateFormat
(
granularity
string
)
string
{
if
f
,
ok
:=
dateFormatWhitelist
[
granularity
];
ok
{
return
f
}
return
"YYYY-MM-DD"
}
type
usageLogRepository
struct
{
client
*
dbent
.
Client
...
...
@@ -111,23 +127,24 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
ip_address,
image_count,
image_size,
media_type,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
groupID
:=
nullInt64
(
log
.
GroupID
)
subscriptionID
:=
nullInt64
(
log
.
SubscriptionID
)
...
...
@@ -136,6 +153,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
userAgent
:=
nullString
(
log
.
UserAgent
)
ipAddress
:=
nullString
(
log
.
IPAddress
)
imageSize
:=
nullString
(
log
.
ImageSize
)
mediaType
:=
nullString
(
log
.
MediaType
)
reasoningEffort
:=
nullString
(
log
.
ReasoningEffort
)
var
requestIDArg
any
...
...
@@ -173,6 +191,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress
,
log
.
ImageCount
,
imageSize
,
mediaType
,
reasoningEffort
,
log
.
CacheTTLOverridden
,
createdAt
,
...
...
@@ -566,7 +585,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}
func
(
r
*
usageLogRepository
)
ListByUserAndTimeRange
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
userID
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
...
...
@@ -812,19 +831,19 @@ func resolveUsageStatsTimezone() string {
}
func
(
r
*
usageLogRepository
)
ListByAPIKeyAndTimeRange
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
apiKeyID
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
func
(
r
*
usageLogRepository
)
ListByAccountAndTimeRange
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
accountID
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
func
(
r
*
usageLogRepository
)
ListByModelAndTimeRange
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
modelName
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
...
...
@@ -896,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
return
stats
,
nil
}
// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。
// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。
func
(
r
*
usageLogRepository
)
GetAccountWindowStatsBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
startTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
AccountStats
,
error
)
{
result
:=
make
(
map
[
int64
]
*
usagestats
.
AccountStats
,
len
(
accountIDs
))
if
len
(
accountIDs
)
==
0
{
return
result
,
nil
}
query
:=
`
SELECT
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = ANY($1) AND created_at >= $2
GROUP BY account_id
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
accountIDs
),
startTime
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
accountID
int64
stats
:=
&
usagestats
.
AccountStats
{}
if
err
:=
rows
.
Scan
(
&
accountID
,
&
stats
.
Requests
,
&
stats
.
Tokens
,
&
stats
.
Cost
,
&
stats
.
StandardCost
,
&
stats
.
UserCost
,
);
err
!=
nil
{
return
nil
,
err
}
result
[
accountID
]
=
stats
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
for
_
,
accountID
:=
range
accountIDs
{
if
_
,
ok
:=
result
[
accountID
];
!
ok
{
result
[
accountID
]
=
&
usagestats
.
AccountStats
{}
}
}
return
result
,
nil
}
// TrendDataPoint represents a single point in trend data
type
TrendDataPoint
=
usagestats
.
TrendDataPoint
...
...
@@ -910,10 +982,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func
(
r
*
usageLogRepository
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
(
results
[]
APIKeyUsageTrendPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
WITH top_keys AS (
...
...
@@ -968,10 +1037,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
// GetUserUsageTrend returns usage trend data grouped by user and date
func
(
r
*
usageLogRepository
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
(
results
[]
UserUsageTrendPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
WITH top_users AS (
...
...
@@ -1230,10 +1296,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func
(
r
*
usageLogRepository
)
GetUserUsageTrendByUserID
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
SELECT
...
...
@@ -1371,13 +1434,22 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type
BatchUserUsageStats
=
usagestats
.
BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func
(
r
*
usageLogRepository
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
BatchUserUsageStats
,
error
)
{
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func
(
r
*
usageLogRepository
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
BatchUserUsageStats
,
error
)
{
result
:=
make
(
map
[
int64
]
*
BatchUserUsageStats
)
if
len
(
userIDs
)
==
0
{
return
result
,
nil
}
// 默认最近 30 天
if
startTime
.
IsZero
()
{
startTime
=
time
.
Now
()
.
AddDate
(
0
,
0
,
-
30
)
}
if
endTime
.
IsZero
()
{
endTime
=
time
.
Now
()
}
for
_
,
id
:=
range
userIDs
{
result
[
id
]
=
&
BatchUserUsageStats
{
UserID
:
id
}
}
...
...
@@ -1385,10 +1457,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
query
:=
`
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE user_id = ANY($1)
WHERE user_id = ANY($1)
AND created_at >= $2 AND created_at < $3
GROUP BY user_id
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
userIDs
))
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
userIDs
)
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1445,13 +1517,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
// BatchAPIKeyUsageStats represents usage stats for a single API key
type
BatchAPIKeyUsageStats
=
usagestats
.
BatchAPIKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func
(
r
*
usageLogRepository
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
BatchAPIKeyUsageStats
,
error
)
{
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
// If startTime is zero, defaults to 30 days ago.
func
(
r
*
usageLogRepository
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
BatchAPIKeyUsageStats
,
error
)
{
result
:=
make
(
map
[
int64
]
*
BatchAPIKeyUsageStats
)
if
len
(
apiKeyIDs
)
==
0
{
return
result
,
nil
}
// 默认最近 30 天
if
startTime
.
IsZero
()
{
startTime
=
time
.
Now
()
.
AddDate
(
0
,
0
,
-
30
)
}
if
endTime
.
IsZero
()
{
endTime
=
time
.
Now
()
}
for
_
,
id
:=
range
apiKeyIDs
{
result
[
id
]
=
&
BatchAPIKeyUsageStats
{
APIKeyID
:
id
}
}
...
...
@@ -1459,10 +1540,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
query
:=
`
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE api_key_id = ANY($1)
WHERE api_key_id = ANY($1)
AND created_at >= $2 AND created_at < $3
GROUP BY api_key_id
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
apiKeyIDs
))
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
apiKeyIDs
)
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1518,10 +1599,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
// GetUsageTrendWithFilters returns usage trend data with optional filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
SELECT
...
...
@@ -2196,6 +2274,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress
sql
.
NullString
imageCount
int
imageSize
sql
.
NullString
mediaType
sql
.
NullString
reasoningEffort
sql
.
NullString
cacheTTLOverridden
bool
createdAt
time
.
Time
...
...
@@ -2232,6 +2311,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
ipAddress
,
&
imageCount
,
&
imageSize
,
&
mediaType
,
&
reasoningEffort
,
&
cacheTTLOverridden
,
&
createdAt
,
...
...
@@ -2294,6 +2374,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if
imageSize
.
Valid
{
log
.
ImageSize
=
&
imageSize
.
String
}
if
mediaType
.
Valid
{
log
.
MediaType
=
&
mediaType
.
String
}
if
reasoningEffort
.
Valid
{
log
.
ReasoningEffort
=
&
reasoningEffort
.
String
}
...
...
backend/internal/repository/usage_log_repo_integration_test.go
View file @
07be258d
...
...
@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
s
.
createUsageLog
(
user1
,
apiKey1
,
account
,
10
,
20
,
0.5
,
time
.
Now
())
s
.
createUsageLog
(
user2
,
apiKey2
,
account
,
15
,
25
,
0.6
,
time
.
Now
())
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{
user1
.
ID
,
user2
.
ID
})
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{
user1
.
ID
,
user2
.
ID
}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
,
"GetBatchUserUsageStats"
)
s
.
Require
()
.
Len
(
stats
,
2
)
s
.
Require
()
.
NotNil
(
stats
[
user1
.
ID
])
...
...
@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
}
func
(
s
*
UsageLogRepoSuite
)
TestGetBatchUserUsageStats_Empty
()
{
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{})
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Empty
(
stats
)
}
...
...
@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
s
.
createUsageLog
(
user
,
apiKey1
,
account
,
10
,
20
,
0.5
,
time
.
Now
())
s
.
createUsageLog
(
user
,
apiKey2
,
account
,
15
,
25
,
0.6
,
time
.
Now
())
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{
apiKey1
.
ID
,
apiKey2
.
ID
})
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{
apiKey1
.
ID
,
apiKey2
.
ID
}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
,
"GetBatchAPIKeyUsageStats"
)
s
.
Require
()
.
Len
(
stats
,
2
)
}
func
(
s
*
UsageLogRepoSuite
)
TestGetBatchApiKeyUsageStats_Empty
()
{
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{})
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Empty
(
stats
)
}
...
...
backend/internal/repository/usage_log_repo_unit_test.go
0 → 100644
View file @
07be258d
//go:build unit
package
repository
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestSafeDateFormat
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
granularity
string
expected
string
}{
// 合法值
{
"hour"
,
"hour"
,
"YYYY-MM-DD HH24:00"
},
{
"day"
,
"day"
,
"YYYY-MM-DD"
},
{
"week"
,
"week"
,
"IYYY-IW"
},
{
"month"
,
"month"
,
"YYYY-MM"
},
// 非法值回退到默认
{
"空字符串"
,
""
,
"YYYY-MM-DD"
},
{
"未知粒度 year"
,
"year"
,
"YYYY-MM-DD"
},
{
"未知粒度 minute"
,
"minute"
,
"YYYY-MM-DD"
},
// 恶意字符串
{
"SQL 注入尝试"
,
"'; DROP TABLE users; --"
,
"YYYY-MM-DD"
},
{
"带引号"
,
"day'"
,
"YYYY-MM-DD"
},
{
"带括号"
,
"day)"
,
"YYYY-MM-DD"
},
{
"Unicode"
,
"日"
,
"YYYY-MM-DD"
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
safeDateFormat
(
tc
.
granularity
)
require
.
Equal
(
t
,
tc
.
expected
,
got
,
"safeDateFormat(%q)"
,
tc
.
granularity
)
})
}
}
backend/internal/repository/wire.go
View file @
07be258d
...
...
@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func
ProvideGitHubReleaseClient
(
cfg
*
config
.
Config
)
service
.
GitHubReleaseClient
{
return
NewGitHubReleaseClient
(
cfg
.
Update
.
ProxyURL
)
return
NewGitHubReleaseClient
(
cfg
.
Update
.
ProxyURL
,
cfg
.
Security
.
ProxyFallback
.
AllowDirectOnError
)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
...
...
@@ -53,12 +53,14 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository
,
NewGroupRepository
,
NewAccountRepository
,
NewSoraAccountRepository
,
// Sora 账号扩展表仓储
NewProxyRepository
,
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewAnnouncementRepository
,
NewAnnouncementReadRepository
,
NewUsageLogRepository
,
NewIdempotencyRepository
,
NewUsageCleanupRepository
,
NewDashboardAggregationRepository
,
NewSettingRepository
,
...
...
backend/internal/server/api_contract_test.go
View file @
07be258d
...
...
@@ -83,6 +83,7 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
...
...
@@ -122,6 +123,7 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
...
...
@@ -184,6 +186,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
...
...
@@ -401,6 +407,7 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"media_type": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
...
...
@@ -593,13 +600,13 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode
:
config
.
RunModeStandard
,
}
userService
:=
service
.
NewUserService
(
userRepo
,
nil
)
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
userRepo
,
groupRepo
,
userSubRepo
,
nil
,
apiKeyCache
,
cfg
)
usageRepo
:=
newStubUsageLogRepo
()
usageService
:=
service
.
NewUsageService
(
usageRepo
,
userRepo
,
nil
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepo
,
userSubRepo
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepo
,
userSubRepo
,
nil
,
nil
,
cfg
)
subscriptionHandler
:=
handler
.
NewSubscriptionHandler
(
subscriptionService
)
redeemService
:=
service
.
NewRedeemService
(
redeemRepo
,
userRepo
,
subscriptionService
,
nil
,
nil
,
nil
,
nil
)
...
...
@@ -608,7 +615,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
nil
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
redeemService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
...
...
@@ -925,6 +932,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
service
.
Account
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -1462,6 +1473,20 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
key
,
ok
:=
r
.
byID
[
id
]
if
!
ok
{
return
service
.
ErrAPIKeyNotFound
}
ts
:=
usedAt
key
.
LastUsedAt
=
&
ts
key
.
UpdatedAt
=
usedAt
clone
:=
*
key
r
.
byID
[
id
]
=
&
clone
r
.
byKey
[
clone
.
Key
]
=
&
clone
return
nil
}
type
stubUsageLogRepo
struct
{
userLogs
map
[
int64
][]
service
.
UsageLog
}
...
...
@@ -1607,11 +1632,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/http.go
View file @
07be258d
...
...
@@ -51,6 +51,9 @@ func ProvideRouter(
if
err
:=
r
.
SetTrustedProxies
(
nil
);
err
!=
nil
{
log
.
Printf
(
"Failed to disable trusted proxies: %v"
,
err
)
}
if
cfg
.
Server
.
Mode
==
"release"
{
log
.
Printf
(
"Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled"
)
}
}
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
,
redisClient
)
...
...
backend/internal/server/middleware/admin_auth.go
View file @
07be258d
...
...
@@ -58,8 +58,13 @@ func adminAuth(
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
if
authHeader
!=
""
{
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
parts
[
0
]
==
"Bearer"
{
if
!
validateJWTForAdmin
(
c
,
parts
[
1
],
authService
,
userService
)
{
if
len
(
parts
)
==
2
&&
strings
.
EqualFold
(
parts
[
0
],
"Bearer"
)
{
token
:=
strings
.
TrimSpace
(
parts
[
1
])
if
token
==
""
{
AbortWithError
(
c
,
401
,
"UNAUTHORIZED"
,
"Authorization required"
)
return
}
if
!
validateJWTForAdmin
(
c
,
token
,
authService
,
userService
)
{
return
}
c
.
Next
()
...
...
@@ -176,6 +181,12 @@ func validateJWTForAdmin(
return
false
}
// 校验 TokenVersion,确保管理员改密后旧 token 失效
if
claims
.
TokenVersion
!=
user
.
TokenVersion
{
AbortWithError
(
c
,
401
,
"TOKEN_REVOKED"
,
"Token has been revoked (password changed)"
)
return
false
}
// 检查管理员权限
if
!
user
.
IsAdmin
()
{
AbortWithError
(
c
,
403
,
"FORBIDDEN"
,
"Admin access required"
)
...
...
backend/internal/server/middleware/admin_auth_test.go
0 → 100644
View file @
07be258d
//go:build unit
package
middleware
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestAdminAuthJWTValidatesTokenVersion
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
}}
authService
:=
service
.
NewAuthService
(
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
)
admin
:=
&
service
.
User
{
ID
:
1
,
Email
:
"admin@example.com"
,
Role
:
service
.
RoleAdmin
,
Status
:
service
.
StatusActive
,
TokenVersion
:
2
,
Concurrency
:
1
,
}
userRepo
:=
&
stubUserRepo
{
getByID
:
func
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
if
id
!=
admin
.
ID
{
return
nil
,
service
.
ErrUserNotFound
}
clone
:=
*
admin
return
&
clone
,
nil
},
}
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAdminAuthMiddleware
(
authService
,
userService
,
nil
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
t
.
Run
(
"token_version_mismatch_rejected"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
-
1
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"TOKEN_REVOKED"
)
})
t
.
Run
(
"token_version_match_allows"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"websocket_token_version_mismatch_rejected"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
-
1
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Upgrade"
,
"websocket"
)
req
.
Header
.
Set
(
"Connection"
,
"Upgrade"
)
req
.
Header
.
Set
(
"Sec-WebSocket-Protocol"
,
"sub2api-admin, jwt."
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"TOKEN_REVOKED"
)
})
t
.
Run
(
"websocket_token_version_match_allows"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Upgrade"
,
"websocket"
)
req
.
Header
.
Set
(
"Connection"
,
"Upgrade"
)
req
.
Header
.
Set
(
"Sec-WebSocket-Protocol"
,
"sub2api-admin, jwt."
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
}
type
stubUserRepo
struct
{
getByID
func
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
}
func
(
s
*
stubUserRepo
)
Create
(
ctx
context
.
Context
,
user
*
service
.
User
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
stubUserRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
if
s
.
getByID
==
nil
{
panic
(
"GetByID not stubbed"
)
}
return
s
.
getByID
(
ctx
,
id
)
}
func
(
s
*
stubUserRepo
)
GetByEmail
(
ctx
context
.
Context
,
email
string
)
(
*
service
.
User
,
error
)
{
panic
(
"unexpected GetByEmail call"
)
}
func
(
s
*
stubUserRepo
)
GetFirstAdmin
(
ctx
context
.
Context
)
(
*
service
.
User
,
error
)
{
panic
(
"unexpected GetFirstAdmin call"
)
}
func
(
s
*
stubUserRepo
)
Update
(
ctx
context
.
Context
,
user
*
service
.
User
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
stubUserRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
stubUserRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
stubUserRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
service
.
UserListFilters
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
stubUserRepo
)
UpdateBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
panic
(
"unexpected UpdateBalance call"
)
}
func
(
s
*
stubUserRepo
)
DeductBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
panic
(
"unexpected DeductBalance call"
)
}
func
(
s
*
stubUserRepo
)
UpdateConcurrency
(
ctx
context
.
Context
,
id
int64
,
amount
int
)
error
{
panic
(
"unexpected UpdateConcurrency call"
)
}
func
(
s
*
stubUserRepo
)
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByEmail call"
)
}
func
(
s
*
stubUserRepo
)
RemoveGroupFromAllowedGroups
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
panic
(
"unexpected RemoveGroupFromAllowedGroups call"
)
}
func
(
s
*
stubUserRepo
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
panic
(
"unexpected UpdateTotpSecret call"
)
}
func
(
s
*
stubUserRepo
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected EnableTotp call"
)
}
func
(
s
*
stubUserRepo
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected DisableTotp call"
)
}
backend/internal/server/middleware/api_key_auth.go
View file @
07be258d
...
...
@@ -3,7 +3,6 @@ package middleware
import
(
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -36,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
if
authHeader
!=
""
{
// 验证Bearer scheme
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
parts
[
0
]
==
"Bearer"
{
apiKeyString
=
parts
[
1
]
if
len
(
parts
)
==
2
&&
strings
.
EqualFold
(
parts
[
0
]
,
"Bearer"
)
{
apiKeyString
=
strings
.
TrimSpace
(
parts
[
1
]
)
}
}
...
...
@@ -97,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if
len
(
apiKey
.
IPWhitelist
)
>
0
||
len
(
apiKey
.
IPBlacklist
)
>
0
{
clientIP
:=
ip
.
GetClientIP
(
c
)
clientIP
:=
ip
.
Get
Trusted
ClientIP
(
c
)
allowed
,
_
:=
ip
.
CheckIPRestriction
(
clientIP
,
apiKey
.
IPWhitelist
,
apiKey
.
IPBlacklist
)
if
!
allowed
{
AbortWithError
(
c
,
403
,
"ACCESS_DENIED"
,
"Access denied"
)
...
...
@@ -126,6 +125,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
_
=
apiKeyService
.
TouchLastUsed
(
c
.
Request
.
Context
(),
apiKey
.
ID
)
c
.
Next
()
return
}
...
...
@@ -134,7 +134,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
isSubscriptionType
:=
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
if
isSubscriptionType
&&
subscriptionService
!=
nil
{
// 订阅模式:
验证订阅
// 订阅模式:
获取订阅(L1 缓存 + singleflight)
subscription
,
err
:=
subscriptionService
.
GetActiveSubscription
(
c
.
Request
.
Context
(),
apiKey
.
User
.
ID
,
...
...
@@ -145,30 +145,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 验证订阅状态(是否过期、暂停等)
if
err
:=
subscriptionService
.
ValidateSubscription
(
c
.
Request
.
Context
(),
subscription
);
err
!=
nil
{
AbortWithError
(
c
,
403
,
"SUBSCRIPTION_INVALID"
,
err
.
Error
())
return
}
// 激活滑动窗口(首次使用时)
if
err
:=
subscriptionService
.
CheckAndActivateWindow
(
c
.
Request
.
Context
(),
subscription
);
err
!=
nil
{
log
.
Printf
(
"Failed to activate subscription windows: %v"
,
err
)
}
// 检查并重置过期窗口
if
err
:=
subscriptionService
.
CheckAndResetWindows
(
c
.
Request
.
Context
(),
subscription
);
err
!=
nil
{
log
.
Printf
(
"Failed to reset subscription windows: %v"
,
err
)
}
// 预检查用量限制(使用0作为额外费用进行预检查)
if
err
:=
subscriptionService
.
CheckUsageLimits
(
c
.
Request
.
Context
(),
subscription
,
apiKey
.
Group
,
0
);
err
!=
nil
{
AbortWithError
(
c
,
429
,
"USAGE_LIMIT_EXCEEDED"
,
err
.
Error
())
// 合并验证 + 限额检查(纯内存操作)
needsMaintenance
,
err
:=
subscriptionService
.
ValidateAndCheckLimits
(
subscription
,
apiKey
.
Group
)
if
err
!=
nil
{
code
:=
"SUBSCRIPTION_INVALID"
status
:=
403
if
errors
.
Is
(
err
,
service
.
ErrDailyLimitExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrWeeklyLimitExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrMonthlyLimitExceeded
)
{
code
=
"USAGE_LIMIT_EXCEEDED"
status
=
429
}
AbortWithError
(
c
,
status
,
code
,
err
.
Error
())
return
}
// 将订阅信息存入上下文
c
.
Set
(
string
(
ContextKeySubscription
),
subscription
)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if
needsMaintenance
{
maintenanceCopy
:=
*
subscription
subscriptionService
.
DoWindowMaintenance
(
&
maintenanceCopy
)
}
}
else
{
// 余额模式:检查用户余额
if
apiKey
.
User
.
Balance
<=
0
{
...
...
@@ -185,6 +185,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
_
=
apiKeyService
.
TouchLastUsed
(
c
.
Request
.
Context
(),
apiKey
.
ID
)
c
.
Next
()
}
...
...
backend/internal/server/middleware/api_key_auth_google.go
View file @
07be258d
...
...
@@ -64,6 +64,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
_
=
apiKeyService
.
TouchLastUsed
(
c
.
Request
.
Context
(),
apiKey
.
ID
)
c
.
Next
()
return
}
...
...
@@ -104,6 +105,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
_
=
apiKeyService
.
TouchLastUsed
(
c
.
Request
.
Context
(),
apiKey
.
ID
)
c
.
Next
()
}
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
07be258d
...
...
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
...
...
@@ -18,7 +19,8 @@ import (
)
type
fakeAPIKeyRepo
struct
{
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
updateLastUsed
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
}
func
(
f
fakeAPIKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
...
...
@@ -78,6 +80,12 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([
func
(
f
fakeAPIKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeAPIKeyRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
if
f
.
updateLastUsed
!=
nil
{
return
f
.
updateLastUsed
(
ctx
,
id
,
usedAt
)
}
return
nil
}
type
googleErrorResponse
struct
{
Error
struct
{
...
...
@@ -356,3 +364,144 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
require
.
Equal
(
t
,
"Insufficient account balance"
,
resp
.
Error
.
Message
)
require
.
Equal
(
t
,
"PERMISSION_DENIED"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
11
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
201
,
UserID
:
user
.
ID
,
Key
:
"google-touch-ok"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
var
touchedID
int64
var
touchedAt
time
.
Time
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchedID
=
id
touchedAt
=
usedAt
return
nil
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-goog-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
apiKey
.
ID
,
touchedID
)
require
.
False
(
t
,
touchedAt
.
IsZero
())
}
func
TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
12
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
202
,
UserID
:
user
.
ID
,
Key
:
"google-touch-fail"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
errors
.
New
(
"write failed"
)
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-goog-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
13
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
203
,
UserID
:
user
.
ID
,
Key
:
"google-touch-standard"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
nil
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
backend/internal/server/middleware/api_key_auth_test.go
View file @
07be258d
...
...
@@ -57,10 +57,61 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
},
}
t
.
Run
(
"standard_mode_needs_maintenance_does_not_block_request"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
cfg
.
SubscriptionMaintenance
.
WorkerCount
=
1
cfg
.
SubscriptionMaintenance
.
QueueSize
=
1
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
past
:=
time
.
Now
()
.
Add
(
-
48
*
time
.
Hour
)
sub
:=
&
service
.
UserSubscription
{
ID
:
55
,
UserID
:
user
.
ID
,
GroupID
:
group
.
ID
,
Status
:
service
.
SubscriptionStatusActive
,
ExpiresAt
:
time
.
Now
()
.
Add
(
24
*
time
.
Hour
),
DailyWindowStart
:
&
past
,
DailyUsageUSD
:
0
,
}
maintenanceCalled
:=
make
(
chan
struct
{},
1
)
subscriptionRepo
:=
&
stubUserSubscriptionRepo
{
getActive
:
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
clone
:=
*
sub
return
&
clone
,
nil
},
updateStatus
:
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
{
return
nil
},
activateWindow
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetDaily
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
maintenanceCalled
<-
struct
{}{}
return
nil
},
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
,
nil
,
cfg
)
t
.
Cleanup
(
subscriptionService
.
Stop
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
select
{
case
<-
maintenanceCalled
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatalf
(
"expected maintenance to be scheduled"
)
}
})
t
.
Run
(
"simple_mode_bypasses_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
...
...
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"simple_mode_accepts_lowercase_bearer"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"bearer "
+
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"standard_mode_enforces_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
...
...
@@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
...
...
@@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
IPWhitelist
:
[]
string
{
"1.2.3.4"
},
}
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
require
.
NoError
(
t
,
router
.
SetTrustedProxies
(
nil
))
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
RemoteAddr
=
"9.9.9.9:12345"
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
.
Header
.
Set
(
"X-Forwarded-For"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"X-Real-IP"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"CF-Connecting-IP"
,
"1.2.3.4"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"ACCESS_DENIED"
)
}
func
TestAPIKeyAuthTouchesLastUsedOnSuccess
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"touch-ok"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
var
touchedID
int64
var
touchedAt
time
.
Time
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchedID
=
id
touchedAt
=
usedAt
return
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
nil
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
require
.
Equal
(
t
,
apiKey
.
ID
,
touchedID
)
require
.
False
(
t
,
touchedAt
.
IsZero
(),
"expected touch timestamp"
)
}
func
TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
8
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
101
,
UserID
:
user
.
ID
,
Key
:
"touch-fail"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
errors
.
New
(
"db unavailable"
)
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
nil
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"touch failure should not block request"
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
func
TestAPIKeyAuthTouchesLastUsedInStandardMode
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
9
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
102
,
UserID
:
user
.
ID
,
Key
:
"touch-standard"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
nil
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
...
@@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService
}
type
stubApiKeyRepo
struct
{
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
updateLastUsed
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
}
func
(
r
*
stubApiKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
...
...
@@ -323,6 +581,13 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
if
r
.
updateLastUsed
!=
nil
{
return
r
.
updateLastUsed
(
ctx
,
id
,
usedAt
)
}
return
nil
}
type
stubUserSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment