Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6901b64f
Commit
6901b64f
authored
Jan 17, 2026
by
cyhhao
Browse files
merge: sync upstream changes
parents
32c47b15
dae0d532
Changes
189
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/dashboard_aggregation_repo.go
View file @
6901b64f
...
...
@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
...
...
@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool {
}
func
(
r
*
dashboardAggregationRepository
)
AggregateRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
startUTC
:=
start
.
UTC
()
endUTC
:=
end
.
UTC
()
if
!
endUTC
.
After
(
startUTC
)
{
loc
:=
timezone
.
Location
()
startLocal
:=
start
.
In
(
loc
)
endLocal
:=
end
.
In
(
loc
)
if
!
endLocal
.
After
(
startLocal
)
{
return
nil
}
hourStart
:=
start
UTC
.
Truncate
(
time
.
Hour
)
hourEnd
:=
end
UTC
.
Truncate
(
time
.
Hour
)
if
end
UTC
.
After
(
hourEnd
)
{
hourStart
:=
start
Local
.
Truncate
(
time
.
Hour
)
hourEnd
:=
end
Local
.
Truncate
(
time
.
Hour
)
if
end
Local
.
After
(
hourEnd
)
{
hourEnd
=
hourEnd
.
Add
(
time
.
Hour
)
}
dayStart
:=
truncateToDay
UTC
(
start
UTC
)
dayEnd
:=
truncateToDay
UTC
(
end
UTC
)
if
end
UTC
.
After
(
dayEnd
)
{
dayStart
:=
truncateToDay
(
start
Local
)
dayEnd
:=
truncateToDay
(
end
Local
)
if
end
Local
.
After
(
dayEnd
)
{
dayEnd
=
dayEnd
.
Add
(
24
*
time
.
Hour
)
}
...
...
@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C
}
func
(
r
*
dashboardAggregationRepository
)
insertHourlyActiveUsers
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
tzName
:=
timezone
.
Name
()
query
:=
`
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE
'UTC'
) AT TIME ZONE
'UTC'
AS bucket_start,
date_trunc('hour', created_at AT TIME ZONE
$3
) AT TIME ZONE
$3
AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
.
UTC
(),
end
.
UTC
()
)
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
,
end
,
tzName
)
return
err
}
func
(
r
*
dashboardAggregationRepository
)
insertDailyActiveUsers
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
tzName
:=
timezone
.
Name
()
query
:=
`
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
(bucket_start AT TIME ZONE
'UTC'
)::date AS bucket_date,
(bucket_start AT TIME ZONE
$3
)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
.
UTC
(),
end
.
UTC
()
)
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
,
end
,
tzName
)
return
err
}
func
(
r
*
dashboardAggregationRepository
)
upsertHourlyAggregates
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
tzName
:=
timezone
.
Name
()
query
:=
`
WITH hourly AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE
'UTC'
) AT TIME ZONE
'UTC'
AS bucket_start,
date_trunc('hour', created_at AT TIME ZONE
$3
) AT TIME ZONE
$3
AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
...
...
@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
.
UTC
(),
end
.
UTC
()
)
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
,
end
,
tzName
)
return
err
}
func
(
r
*
dashboardAggregationRepository
)
upsertDailyAggregates
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
tzName
:=
timezone
.
Name
()
query
:=
`
WITH daily AS (
SELECT
(bucket_start AT TIME ZONE
'UTC'
)::date AS bucket_date,
(bucket_start AT TIME ZONE
$5
)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
...
...
@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE
'UTC'
)::date
GROUP BY (bucket_start AT TIME ZONE
$5
)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
...
...
@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
.
UTC
(),
end
.
UTC
(),
start
.
UTC
(),
end
.
UTC
()
)
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
query
,
start
,
end
,
start
,
end
,
tzName
)
return
err
}
...
...
@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co
return
err
}
func
truncateToDayUTC
(
t
time
.
Time
)
time
.
Time
{
t
=
t
.
UTC
()
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
time
.
UTC
)
func
truncateToDay
(
t
time
.
Time
)
time
.
Time
{
return
timezone
.
StartOfDay
(
t
)
}
func
truncateToMonthUTC
(
t
time
.
Time
)
time
.
Time
{
...
...
backend/internal/repository/gemini_token_cache.go
View file @
6901b64f
...
...
@@ -11,8 +11,8 @@ import (
)
const
(
gemini
TokenKeyPrefix
=
"
gemini
:token:"
gemini
RefreshLockKeyPrefix
=
"
gemini
:refresh_lock:"
oauth
TokenKeyPrefix
=
"
oauth
:token:"
oauth
RefreshLockKeyPrefix
=
"
oauth
:refresh_lock:"
)
type
geminiTokenCache
struct
{
...
...
@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
}
func
(
c
*
geminiTokenCache
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
}
func
(
c
*
geminiTokenCache
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
token
,
ttl
)
.
Err
()
}
func
(
c
*
geminiTokenCache
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauthTokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
func
(
c
*
geminiTokenCache
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
RefreshLockKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
RefreshLockKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
1
,
ttl
)
.
Result
()
}
func
(
c
*
geminiTokenCache
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
RefreshLockKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
RefreshLockKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
backend/internal/repository/gemini_token_cache_integration_test.go
0 → 100644
View file @
6901b64f
//go:build integration
package
repository
import
(
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
GeminiTokenCacheSuite
struct
{
IntegrationRedisSuite
cache
service
.
GeminiTokenCache
}
func
(
s
*
GeminiTokenCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewGeminiTokenCache
(
s
.
rdb
)
}
func
(
s
*
GeminiTokenCacheSuite
)
TestDeleteAccessToken
()
{
cacheKey
:=
"project-123"
token
:=
"token-value"
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetAccessToken
(
s
.
ctx
,
cacheKey
,
token
,
time
.
Minute
))
got
,
err
:=
s
.
cache
.
GetAccessToken
(
s
.
ctx
,
cacheKey
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
token
,
got
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DeleteAccessToken
(
s
.
ctx
,
cacheKey
))
_
,
err
=
s
.
cache
.
GetAccessToken
(
s
.
ctx
,
cacheKey
)
require
.
True
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected redis.Nil after delete"
)
}
func
(
s
*
GeminiTokenCacheSuite
)
TestDeleteAccessToken_MissingKey
()
{
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DeleteAccessToken
(
s
.
ctx
,
"missing-key"
))
}
func
TestGeminiTokenCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GeminiTokenCacheSuite
))
}
backend/internal/repository/gemini_token_cache_test.go
0 → 100644
View file @
6901b64f
//go:build unit
package
repository
import
(
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func
TestGeminiTokenCache_DeleteAccessToken_RedisError
(
t
*
testing
.
T
)
{
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
"127.0.0.1:1"
,
DialTimeout
:
50
*
time
.
Millisecond
,
ReadTimeout
:
50
*
time
.
Millisecond
,
WriteTimeout
:
50
*
time
.
Millisecond
,
})
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
cache
:=
NewGeminiTokenCache
(
rdb
)
err
:=
cache
.
DeleteAccessToken
(
context
.
Background
(),
"broken"
)
require
.
Error
(
t
,
err
)
}
backend/internal/repository/group_repo.go
View file @
6901b64f
...
...
@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
// 设置模型路由配置
if
groupIn
.
ModelRouting
!=
nil
{
builder
=
builder
.
SetModelRouting
(
groupIn
.
ModelRouting
)
}
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
==
nil
{
...
...
@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k
(
groupIn
.
ImagePrice2K
)
.
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
// 处理 FallbackGroupID:nil 时清除,否则设置
if
groupIn
.
FallbackGroupID
!=
nil
{
...
...
@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder
=
builder
.
ClearFallbackGroupID
()
}
// 处理 ModelRouting:nil 时清除,否则设置
if
groupIn
.
ModelRouting
!=
nil
{
builder
=
builder
.
SetModelRouting
(
groupIn
.
ModelRouting
)
}
else
{
builder
=
builder
.
ClearModelRouting
()
}
updated
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
service
.
ErrGroupExists
)
...
...
backend/internal/repository/ops_repo.go
View file @
6901b64f
...
...
@@ -55,7 +55,6 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
duration_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
...
...
@@ -65,7 +64,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
,$35
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
) RETURNING id`
var
id
int64
...
...
@@ -98,7 +97,6 @@ INSERT INTO ops_error_logs (
opsNullString
(
input
.
UpstreamErrorMessage
),
opsNullString
(
input
.
UpstreamErrorDetail
),
opsNullString
(
input
.
UpstreamErrorsJSON
),
opsNullInt
(
input
.
DurationMs
),
opsNullInt64
(
input
.
TimeToFirstTokenMs
),
opsNullString
(
input
.
RequestBodyJSON
),
input
.
RequestBodyTruncated
,
...
...
@@ -135,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
}
where
,
args
:=
buildOpsErrorLogsWhere
(
filter
)
countSQL
:=
"SELECT COUNT(*) FROM ops_error_logs "
+
where
countSQL
:=
"SELECT COUNT(*) FROM ops_error_logs
e
"
+
where
var
total
int
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
countSQL
,
args
...
)
.
Scan
(
&
total
);
err
!=
nil
{
...
...
@@ -146,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
argsWithLimit
:=
append
(
args
,
pageSize
,
offset
)
selectSQL
:=
`
SELECT
id,
created_at,
error_phase,
error_type,
severity,
COALESCE(upstream_status_code, status_code, 0),
COALESCE(platform, ''),
COALESCE(model, ''),
duration_ms,
COALESCE(client_request_id, ''),
COALESCE(request_id, ''),
COALESCE(error_message, ''),
user_id,
api_key_id,
account_id,
group_id,
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
COALESCE(request_path, ''),
stream
FROM ops_error_logs
e.id,
e.created_at,
e.error_phase,
e.error_type,
COALESCE(e.error_owner, ''),
COALESCE(e.error_source, ''),
e.severity,
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
COALESCE(e.is_retryable, false),
COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
COALESCE(u2.email, ''),
e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
e.user_id,
COALESCE(u.email, ''),
e.api_key_id,
e.account_id,
COALESCE(a.name, ''),
e.group_id,
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream
FROM ops_error_logs e
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN users u2 ON e.resolved_by_user_id = u2.id
`
+
where
+
`
ORDER BY created_at DESC
ORDER BY
e.
created_at DESC
LIMIT $`
+
itoa
(
len
(
args
)
+
1
)
+
` OFFSET $`
+
itoa
(
len
(
args
)
+
2
)
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
selectSQL
,
argsWithLimit
...
)
...
...
@@ -179,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
out
:=
make
([]
*
service
.
OpsErrorLog
,
0
,
pageSize
)
for
rows
.
Next
()
{
var
item
service
.
OpsErrorLog
var
latency
sql
.
NullInt64
var
statusCode
sql
.
NullInt64
var
clientIP
sql
.
NullString
var
userID
sql
.
NullInt64
var
apiKeyID
sql
.
NullInt64
var
accountID
sql
.
NullInt64
var
accountName
string
var
groupID
sql
.
NullInt64
var
groupName
string
var
userEmail
string
var
resolvedAt
sql
.
NullTime
var
resolvedBy
sql
.
NullInt64
var
resolvedByName
string
var
resolvedRetryID
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
item
.
ID
,
&
item
.
CreatedAt
,
&
item
.
Phase
,
&
item
.
Type
,
&
item
.
Owner
,
&
item
.
Source
,
&
item
.
Severity
,
&
statusCode
,
&
item
.
Platform
,
&
item
.
Model
,
&
latency
,
&
item
.
IsRetryable
,
&
item
.
RetryCount
,
&
item
.
Resolved
,
&
resolvedAt
,
&
resolvedBy
,
&
resolvedByName
,
&
resolvedRetryID
,
&
item
.
ClientRequestID
,
&
item
.
RequestID
,
&
item
.
Message
,
&
userID
,
&
userEmail
,
&
apiKeyID
,
&
accountID
,
&
accountName
,
&
groupID
,
&
groupName
,
&
clientIP
,
&
item
.
RequestPath
,
&
item
.
Stream
,
);
err
!=
nil
{
return
nil
,
err
}
if
latency
.
Valid
{
v
:=
int
(
latency
.
Int64
)
item
.
LatencyMs
=
&
v
if
resolvedAt
.
Valid
{
t
:=
resolvedAt
.
Time
item
.
ResolvedAt
=
&
t
}
if
resolvedBy
.
Valid
{
v
:=
resolvedBy
.
Int64
item
.
ResolvedByUserID
=
&
v
}
item
.
ResolvedByUserName
=
resolvedByName
if
resolvedRetryID
.
Valid
{
v
:=
resolvedRetryID
.
Int64
item
.
ResolvedRetryID
=
&
v
}
item
.
StatusCode
=
int
(
statusCode
.
Int64
)
if
clientIP
.
Valid
{
...
...
@@ -222,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v
:=
userID
.
Int64
item
.
UserID
=
&
v
}
item
.
UserEmail
=
userEmail
if
apiKeyID
.
Valid
{
v
:=
apiKeyID
.
Int64
item
.
APIKeyID
=
&
v
...
...
@@ -230,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v
:=
accountID
.
Int64
item
.
AccountID
=
&
v
}
item
.
AccountName
=
accountName
if
groupID
.
Valid
{
v
:=
groupID
.
Int64
item
.
GroupID
=
&
v
}
item
.
GroupName
=
groupName
out
=
append
(
out
,
&
item
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
...
...
@@ -258,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service
q
:=
`
SELECT
id,
created_at,
error_phase,
error_type,
severity,
COALESCE(upstream_status_code, status_code, 0),
COALESCE(platform, ''),
COALESCE(model, ''),
duration_ms,
COALESCE(client_request_id, ''),
COALESCE(request_id, ''),
COALESCE(error_message, ''),
COALESCE(error_body, ''),
upstream_status_code,
COALESCE(upstream_error_message, ''),
COALESCE(upstream_error_detail, ''),
COALESCE(upstream_errors::text, ''),
is_business_limited,
user_id,
api_key_id,
account_id,
group_id,
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
COALESCE(request_path, ''),
stream,
COALESCE(user_agent, ''),
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
COALESCE(request_body::text, ''),
request_body_truncated,
request_body_bytes,
COALESCE(request_headers::text, '')
FROM ops_error_logs
WHERE id = $1
e.id,
e.created_at,
e.error_phase,
e.error_type,
COALESCE(e.error_owner, ''),
COALESCE(e.error_source, ''),
e.severity,
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
COALESCE(e.is_retryable, false),
COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
COALESCE(e.error_body, ''),
e.upstream_status_code,
COALESCE(e.upstream_error_message, ''),
COALESCE(e.upstream_error_detail, ''),
COALESCE(e.upstream_errors::text, ''),
e.is_business_limited,
e.user_id,
COALESCE(u.email, ''),
e.api_key_id,
e.account_id,
COALESCE(a.name, ''),
e.group_id,
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream,
COALESCE(e.user_agent, ''),
e.auth_latency_ms,
e.routing_latency_ms,
e.upstream_latency_ms,
e.response_latency_ms,
e.time_to_first_token_ms,
COALESCE(e.request_body::text, ''),
e.request_body_truncated,
e.request_body_bytes,
COALESCE(e.request_headers::text, '')
FROM ops_error_logs e
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
WHERE e.id = $1
LIMIT 1`
var
out
service
.
OpsErrorLogDetail
var
latency
sql
.
NullInt64
var
statusCode
sql
.
NullInt64
var
upstreamStatusCode
sql
.
NullInt64
var
resolvedAt
sql
.
NullTime
var
resolvedBy
sql
.
NullInt64
var
resolvedRetryID
sql
.
NullInt64
var
clientIP
sql
.
NullString
var
userID
sql
.
NullInt64
var
apiKeyID
sql
.
NullInt64
...
...
@@ -318,11 +375,18 @@ LIMIT 1`
&
out
.
CreatedAt
,
&
out
.
Phase
,
&
out
.
Type
,
&
out
.
Owner
,
&
out
.
Source
,
&
out
.
Severity
,
&
statusCode
,
&
out
.
Platform
,
&
out
.
Model
,
&
latency
,
&
out
.
IsRetryable
,
&
out
.
RetryCount
,
&
out
.
Resolved
,
&
resolvedAt
,
&
resolvedBy
,
&
resolvedRetryID
,
&
out
.
ClientRequestID
,
&
out
.
RequestID
,
&
out
.
Message
,
...
...
@@ -333,9 +397,12 @@ LIMIT 1`
&
out
.
UpstreamErrors
,
&
out
.
IsBusinessLimited
,
&
userID
,
&
out
.
UserEmail
,
&
apiKeyID
,
&
accountID
,
&
out
.
AccountName
,
&
groupID
,
&
out
.
GroupName
,
&
clientIP
,
&
out
.
RequestPath
,
&
out
.
Stream
,
...
...
@@ -355,9 +422,17 @@ LIMIT 1`
}
out
.
StatusCode
=
int
(
statusCode
.
Int64
)
if
latency
.
Valid
{
v
:=
int
(
latency
.
Int64
)
out
.
LatencyMs
=
&
v
if
resolvedAt
.
Valid
{
t
:=
resolvedAt
.
Time
out
.
ResolvedAt
=
&
t
}
if
resolvedBy
.
Valid
{
v
:=
resolvedBy
.
Int64
out
.
ResolvedByUserID
=
&
v
}
if
resolvedRetryID
.
Valid
{
v
:=
resolvedRetryID
.
Int64
out
.
ResolvedRetryID
=
&
v
}
if
clientIP
.
Valid
{
s
:=
clientIP
.
String
...
...
@@ -487,9 +562,15 @@ SET
status = $2,
finished_at = $3,
duration_ms = $4,
result_request_id = $5,
result_error_id = $6,
error_message = $7
success = $5,
http_status_code = $6,
upstream_request_id = $7,
used_account_id = $8,
response_preview = $9,
response_truncated = $10,
result_request_id = $11,
result_error_id = $12,
error_message = $13
WHERE id = $1`
_
,
err
:=
r
.
db
.
ExecContext
(
...
...
@@ -499,8 +580,14 @@ WHERE id = $1`
strings
.
TrimSpace
(
input
.
Status
),
nullTime
(
input
.
FinishedAt
),
input
.
DurationMs
,
nullBool
(
input
.
Success
),
nullInt
(
input
.
HTTPStatusCode
),
opsNullString
(
input
.
UpstreamRequestID
),
nullInt64
(
input
.
UsedAccountID
),
opsNullString
(
input
.
ResponsePreview
),
nullBool
(
input
.
ResponseTruncated
),
opsNullString
(
input
.
ResultRequestID
),
opsN
ullInt64
(
input
.
ResultErrorID
),
n
ullInt64
(
input
.
ResultErrorID
),
opsNullString
(
input
.
ErrorMessage
),
)
return
err
...
...
@@ -526,6 +613,12 @@ SELECT
started_at,
finished_at,
duration_ms,
success,
http_status_code,
upstream_request_id,
used_account_id,
response_preview,
response_truncated,
result_request_id,
result_error_id,
error_message
...
...
@@ -540,6 +633,12 @@ LIMIT 1`
var
startedAt
sql
.
NullTime
var
finishedAt
sql
.
NullTime
var
durationMs
sql
.
NullInt64
var
success
sql
.
NullBool
var
httpStatusCode
sql
.
NullInt64
var
upstreamRequestID
sql
.
NullString
var
usedAccountID
sql
.
NullInt64
var
responsePreview
sql
.
NullString
var
responseTruncated
sql
.
NullBool
var
resultRequestID
sql
.
NullString
var
resultErrorID
sql
.
NullInt64
var
errorMessage
sql
.
NullString
...
...
@@ -555,6 +654,12 @@ LIMIT 1`
&
startedAt
,
&
finishedAt
,
&
durationMs
,
&
success
,
&
httpStatusCode
,
&
upstreamRequestID
,
&
usedAccountID
,
&
responsePreview
,
&
responseTruncated
,
&
resultRequestID
,
&
resultErrorID
,
&
errorMessage
,
...
...
@@ -579,6 +684,30 @@ LIMIT 1`
v
:=
durationMs
.
Int64
out
.
DurationMs
=
&
v
}
if
success
.
Valid
{
v
:=
success
.
Bool
out
.
Success
=
&
v
}
if
httpStatusCode
.
Valid
{
v
:=
int
(
httpStatusCode
.
Int64
)
out
.
HTTPStatusCode
=
&
v
}
if
upstreamRequestID
.
Valid
{
s
:=
upstreamRequestID
.
String
out
.
UpstreamRequestID
=
&
s
}
if
usedAccountID
.
Valid
{
v
:=
usedAccountID
.
Int64
out
.
UsedAccountID
=
&
v
}
if
responsePreview
.
Valid
{
s
:=
responsePreview
.
String
out
.
ResponsePreview
=
&
s
}
if
responseTruncated
.
Valid
{
v
:=
responseTruncated
.
Bool
out
.
ResponseTruncated
=
&
v
}
if
resultRequestID
.
Valid
{
s
:=
resultRequestID
.
String
out
.
ResultRequestID
=
&
s
...
...
@@ -602,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime {
return
sql
.
NullTime
{
Time
:
t
,
Valid
:
true
}
}
func
nullBool
(
v
*
bool
)
sql
.
NullBool
{
if
v
==
nil
{
return
sql
.
NullBool
{}
}
return
sql
.
NullBool
{
Bool
:
*
v
,
Valid
:
true
}
}
func
(
r
*
opsRepository
)
ListRetryAttemptsByErrorID
(
ctx
context
.
Context
,
sourceErrorID
int64
,
limit
int
)
([]
*
service
.
OpsRetryAttempt
,
error
)
{
if
r
==
nil
||
r
.
db
==
nil
{
return
nil
,
fmt
.
Errorf
(
"nil ops repository"
)
}
if
sourceErrorID
<=
0
{
return
nil
,
fmt
.
Errorf
(
"invalid source_error_id"
)
}
if
limit
<=
0
{
limit
=
50
}
if
limit
>
200
{
limit
=
200
}
q
:=
`
SELECT
r.id,
r.created_at,
COALESCE(r.requested_by_user_id, 0),
r.source_error_id,
COALESCE(r.mode, ''),
r.pinned_account_id,
COALESCE(pa.name, ''),
COALESCE(r.status, ''),
r.started_at,
r.finished_at,
r.duration_ms,
r.success,
r.http_status_code,
r.upstream_request_id,
r.used_account_id,
COALESCE(ua.name, ''),
r.response_preview,
r.response_truncated,
r.result_request_id,
r.result_error_id,
r.error_message
FROM ops_retry_attempts r
LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
LEFT JOIN accounts ua ON r.used_account_id = ua.id
WHERE r.source_error_id = $1
ORDER BY r.created_at DESC
LIMIT $2`
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
q
,
sourceErrorID
,
limit
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
out
:=
make
([]
*
service
.
OpsRetryAttempt
,
0
,
16
)
for
rows
.
Next
()
{
var
item
service
.
OpsRetryAttempt
var
pinnedAccountID
sql
.
NullInt64
var
pinnedAccountName
string
var
requestedBy
sql
.
NullInt64
var
startedAt
sql
.
NullTime
var
finishedAt
sql
.
NullTime
var
durationMs
sql
.
NullInt64
var
success
sql
.
NullBool
var
httpStatusCode
sql
.
NullInt64
var
upstreamRequestID
sql
.
NullString
var
usedAccountID
sql
.
NullInt64
var
usedAccountName
string
var
responsePreview
sql
.
NullString
var
responseTruncated
sql
.
NullBool
var
resultRequestID
sql
.
NullString
var
resultErrorID
sql
.
NullInt64
var
errorMessage
sql
.
NullString
if
err
:=
rows
.
Scan
(
&
item
.
ID
,
&
item
.
CreatedAt
,
&
requestedBy
,
&
item
.
SourceErrorID
,
&
item
.
Mode
,
&
pinnedAccountID
,
&
pinnedAccountName
,
&
item
.
Status
,
&
startedAt
,
&
finishedAt
,
&
durationMs
,
&
success
,
&
httpStatusCode
,
&
upstreamRequestID
,
&
usedAccountID
,
&
usedAccountName
,
&
responsePreview
,
&
responseTruncated
,
&
resultRequestID
,
&
resultErrorID
,
&
errorMessage
,
);
err
!=
nil
{
return
nil
,
err
}
item
.
RequestedByUserID
=
requestedBy
.
Int64
if
pinnedAccountID
.
Valid
{
v
:=
pinnedAccountID
.
Int64
item
.
PinnedAccountID
=
&
v
}
item
.
PinnedAccountName
=
pinnedAccountName
if
startedAt
.
Valid
{
t
:=
startedAt
.
Time
item
.
StartedAt
=
&
t
}
if
finishedAt
.
Valid
{
t
:=
finishedAt
.
Time
item
.
FinishedAt
=
&
t
}
if
durationMs
.
Valid
{
v
:=
durationMs
.
Int64
item
.
DurationMs
=
&
v
}
if
success
.
Valid
{
v
:=
success
.
Bool
item
.
Success
=
&
v
}
if
httpStatusCode
.
Valid
{
v
:=
int
(
httpStatusCode
.
Int64
)
item
.
HTTPStatusCode
=
&
v
}
if
upstreamRequestID
.
Valid
{
item
.
UpstreamRequestID
=
&
upstreamRequestID
.
String
}
if
usedAccountID
.
Valid
{
v
:=
usedAccountID
.
Int64
item
.
UsedAccountID
=
&
v
}
item
.
UsedAccountName
=
usedAccountName
if
responsePreview
.
Valid
{
item
.
ResponsePreview
=
&
responsePreview
.
String
}
if
responseTruncated
.
Valid
{
v
:=
responseTruncated
.
Bool
item
.
ResponseTruncated
=
&
v
}
if
resultRequestID
.
Valid
{
item
.
ResultRequestID
=
&
resultRequestID
.
String
}
if
resultErrorID
.
Valid
{
v
:=
resultErrorID
.
Int64
item
.
ResultErrorID
=
&
v
}
if
errorMessage
.
Valid
{
item
.
ErrorMessage
=
&
errorMessage
.
String
}
out
=
append
(
out
,
&
item
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
func
(
r
*
opsRepository
)
UpdateErrorResolution
(
ctx
context
.
Context
,
errorID
int64
,
resolved
bool
,
resolvedByUserID
*
int64
,
resolvedRetryID
*
int64
,
resolvedAt
*
time
.
Time
)
error
{
if
r
==
nil
||
r
.
db
==
nil
{
return
fmt
.
Errorf
(
"nil ops repository"
)
}
if
errorID
<=
0
{
return
fmt
.
Errorf
(
"invalid error id"
)
}
q
:=
`
UPDATE ops_error_logs
SET
resolved = $2,
resolved_at = $3,
resolved_by_user_id = $4,
resolved_retry_id = $5
WHERE id = $1`
at
:=
sql
.
NullTime
{}
if
resolvedAt
!=
nil
&&
!
resolvedAt
.
IsZero
()
{
at
=
sql
.
NullTime
{
Time
:
resolvedAt
.
UTC
(),
Valid
:
true
}
}
else
if
resolved
{
now
:=
time
.
Now
()
.
UTC
()
at
=
sql
.
NullTime
{
Time
:
now
,
Valid
:
true
}
}
_
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
q
,
errorID
,
resolved
,
at
,
nullInt64
(
resolvedByUserID
),
nullInt64
(
resolvedRetryID
),
)
return
err
}
func
buildOpsErrorLogsWhere
(
filter
*
service
.
OpsErrorLogFilter
)
(
string
,
[]
any
)
{
clauses
:=
make
([]
string
,
0
,
8
)
args
:=
make
([]
any
,
0
,
8
)
clauses
:=
make
([]
string
,
0
,
12
)
args
:=
make
([]
any
,
0
,
12
)
clauses
=
append
(
clauses
,
"1=1"
)
phaseFilter
:=
""
if
filter
!=
nil
{
phaseFilter
=
strings
.
TrimSpace
(
strings
.
ToLower
(
filter
.
Phase
))
}
// ops_error_logs
primarily
stores client-visible error requests (status>=400),
// ops_error_logs stores client-visible error requests (status>=400),
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
// By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
// If Resolved is not specified, do not filter by resolved state (backward-compatible).
resolvedFilter
:=
(
*
bool
)(
nil
)
if
filter
!=
nil
{
resolvedFilter
=
filter
.
Resolved
}
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if
phaseFilter
!=
"upstream"
{
clauses
=
append
(
clauses
,
"COALESCE(status_code, 0) >= 400"
)
}
if
filter
.
StartTime
!=
nil
&&
!
filter
.
StartTime
.
IsZero
()
{
args
=
append
(
args
,
filter
.
StartTime
.
UTC
())
clauses
=
append
(
clauses
,
"created_at >= $"
+
itoa
(
len
(
args
)))
clauses
=
append
(
clauses
,
"
e.
created_at >= $"
+
itoa
(
len
(
args
)))
}
if
filter
.
EndTime
!=
nil
&&
!
filter
.
EndTime
.
IsZero
()
{
args
=
append
(
args
,
filter
.
EndTime
.
UTC
())
// Keep time-window semantics consistent with other ops queries: [start, end)
clauses
=
append
(
clauses
,
"created_at < $"
+
itoa
(
len
(
args
)))
clauses
=
append
(
clauses
,
"
e.
created_at < $"
+
itoa
(
len
(
args
)))
}
if
p
:=
strings
.
TrimSpace
(
filter
.
Platform
);
p
!=
""
{
args
=
append
(
args
,
p
)
...
...
@@ -643,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
args
=
append
(
args
,
phase
)
clauses
=
append
(
clauses
,
"error_phase = $"
+
itoa
(
len
(
args
)))
}
if
filter
!=
nil
{
if
owner
:=
strings
.
TrimSpace
(
strings
.
ToLower
(
filter
.
Owner
));
owner
!=
""
{
args
=
append
(
args
,
owner
)
clauses
=
append
(
clauses
,
"LOWER(COALESCE(error_owner,'')) = $"
+
itoa
(
len
(
args
)))
}
if
source
:=
strings
.
TrimSpace
(
strings
.
ToLower
(
filter
.
Source
));
source
!=
""
{
args
=
append
(
args
,
source
)
clauses
=
append
(
clauses
,
"LOWER(COALESCE(error_source,'')) = $"
+
itoa
(
len
(
args
)))
}
}
if
resolvedFilter
!=
nil
{
args
=
append
(
args
,
*
resolvedFilter
)
clauses
=
append
(
clauses
,
"COALESCE(resolved,false) = $"
+
itoa
(
len
(
args
)))
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
view
:=
""
if
filter
!=
nil
{
view
=
strings
.
ToLower
(
strings
.
TrimSpace
(
filter
.
View
))
}
switch
view
{
case
""
,
"errors"
:
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
case
"excluded"
:
clauses
=
append
(
clauses
,
"(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))"
)
case
"all"
:
// no-op
default
:
// treat unknown as default 'errors'
clauses
=
append
(
clauses
,
"COALESCE(is_business_limited,false) = false"
)
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)"
)
}
if
len
(
filter
.
StatusCodes
)
>
0
{
args
=
append
(
args
,
pq
.
Array
(
filter
.
StatusCodes
))
clauses
=
append
(
clauses
,
"COALESCE(upstream_status_code, status_code, 0) = ANY($"
+
itoa
(
len
(
args
))
+
")"
)
}
else
if
filter
.
StatusCodesOther
{
// "Other" means: status codes not in the common list.
known
:=
[]
int
{
400
,
401
,
403
,
404
,
409
,
422
,
429
,
500
,
502
,
503
,
504
,
529
}
args
=
append
(
args
,
pq
.
Array
(
known
))
clauses
=
append
(
clauses
,
"NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"
+
itoa
(
len
(
args
))
+
"))"
)
}
// Exact correlation keys (preferred for request↔upstream linkage).
if
rid
:=
strings
.
TrimSpace
(
filter
.
RequestID
);
rid
!=
""
{
args
=
append
(
args
,
rid
)
clauses
=
append
(
clauses
,
"COALESCE(request_id,'') = $"
+
itoa
(
len
(
args
)))
}
if
crid
:=
strings
.
TrimSpace
(
filter
.
ClientRequestID
);
crid
!=
""
{
args
=
append
(
args
,
crid
)
clauses
=
append
(
clauses
,
"COALESCE(client_request_id,'') = $"
+
itoa
(
len
(
args
)))
}
if
q
:=
strings
.
TrimSpace
(
filter
.
Query
);
q
!=
""
{
like
:=
"%"
+
q
+
"%"
args
=
append
(
args
,
like
)
...
...
@@ -654,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses
=
append
(
clauses
,
"(request_id ILIKE $"
+
n
+
" OR client_request_id ILIKE $"
+
n
+
" OR error_message ILIKE $"
+
n
+
")"
)
}
if
userQuery
:=
strings
.
TrimSpace
(
filter
.
UserQuery
);
userQuery
!=
""
{
like
:=
"%"
+
userQuery
+
"%"
args
=
append
(
args
,
like
)
n
:=
itoa
(
len
(
args
))
clauses
=
append
(
clauses
,
"u.email ILIKE $"
+
n
)
}
return
"WHERE "
+
strings
.
Join
(
clauses
,
" AND "
),
args
}
...
...
backend/internal/repository/ops_repo_alerts.go
View file @
6901b64f
...
...
@@ -354,7 +354,7 @@ SELECT
created_at
FROM ops_alert_events
`
+
where
+
`
ORDER BY fired_at DESC
ORDER BY fired_at DESC
, id DESC
LIMIT `
+
limitArg
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
q
,
args
...
)
...
...
@@ -413,6 +413,43 @@ LIMIT ` + limitArg
return
out
,
nil
}
func
(
r
*
opsRepository
)
GetAlertEventByID
(
ctx
context
.
Context
,
eventID
int64
)
(
*
service
.
OpsAlertEvent
,
error
)
{
if
r
==
nil
||
r
.
db
==
nil
{
return
nil
,
fmt
.
Errorf
(
"nil ops repository"
)
}
if
eventID
<=
0
{
return
nil
,
fmt
.
Errorf
(
"invalid event id"
)
}
q
:=
`
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE id = $1`
row
:=
r
.
db
.
QueryRowContext
(
ctx
,
q
,
eventID
)
ev
,
err
:=
scanOpsAlertEvent
(
row
)
if
err
!=
nil
{
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
}
return
nil
,
err
}
return
ev
,
nil
}
func
(
r
*
opsRepository
)
GetActiveAlertEvent
(
ctx
context
.
Context
,
ruleID
int64
)
(
*
service
.
OpsAlertEvent
,
error
)
{
if
r
==
nil
||
r
.
db
==
nil
{
return
nil
,
fmt
.
Errorf
(
"nil ops repository"
)
...
...
@@ -591,6 +628,121 @@ type opsAlertEventRow interface {
Scan
(
dest
...
any
)
error
}
func
(
r
*
opsRepository
)
CreateAlertSilence
(
ctx
context
.
Context
,
input
*
service
.
OpsAlertSilence
)
(
*
service
.
OpsAlertSilence
,
error
)
{
if
r
==
nil
||
r
.
db
==
nil
{
return
nil
,
fmt
.
Errorf
(
"nil ops repository"
)
}
if
input
==
nil
{
return
nil
,
fmt
.
Errorf
(
"nil input"
)
}
if
input
.
RuleID
<=
0
{
return
nil
,
fmt
.
Errorf
(
"invalid rule_id"
)
}
platform
:=
strings
.
TrimSpace
(
input
.
Platform
)
if
platform
==
""
{
return
nil
,
fmt
.
Errorf
(
"invalid platform"
)
}
if
input
.
Until
.
IsZero
()
{
return
nil
,
fmt
.
Errorf
(
"invalid until"
)
}
q
:=
`
INSERT INTO ops_alert_silences (
rule_id,
platform,
group_id,
region,
until,
reason,
created_by,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,NOW()
)
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
row
:=
r
.
db
.
QueryRowContext
(
ctx
,
q
,
input
.
RuleID
,
platform
,
opsNullInt64
(
input
.
GroupID
),
opsNullString
(
input
.
Region
),
input
.
Until
,
opsNullString
(
input
.
Reason
),
opsNullInt64
(
input
.
CreatedBy
),
)
var
out
service
.
OpsAlertSilence
var
groupID
sql
.
NullInt64
var
region
sql
.
NullString
var
createdBy
sql
.
NullInt64
if
err
:=
row
.
Scan
(
&
out
.
ID
,
&
out
.
RuleID
,
&
out
.
Platform
,
&
groupID
,
&
region
,
&
out
.
Until
,
&
out
.
Reason
,
&
createdBy
,
&
out
.
CreatedAt
,
);
err
!=
nil
{
return
nil
,
err
}
if
groupID
.
Valid
{
v
:=
groupID
.
Int64
out
.
GroupID
=
&
v
}
if
region
.
Valid
{
v
:=
strings
.
TrimSpace
(
region
.
String
)
if
v
!=
""
{
out
.
Region
=
&
v
}
}
if
createdBy
.
Valid
{
v
:=
createdBy
.
Int64
out
.
CreatedBy
=
&
v
}
return
&
out
,
nil
}
func
(
r
*
opsRepository
)
IsAlertSilenced
(
ctx
context
.
Context
,
ruleID
int64
,
platform
string
,
groupID
*
int64
,
region
*
string
,
now
time
.
Time
)
(
bool
,
error
)
{
if
r
==
nil
||
r
.
db
==
nil
{
return
false
,
fmt
.
Errorf
(
"nil ops repository"
)
}
if
ruleID
<=
0
{
return
false
,
fmt
.
Errorf
(
"invalid rule id"
)
}
platform
=
strings
.
TrimSpace
(
platform
)
if
platform
==
""
{
return
false
,
nil
}
if
now
.
IsZero
()
{
now
=
time
.
Now
()
.
UTC
()
}
q
:=
`
SELECT 1
FROM ops_alert_silences
WHERE rule_id = $1
AND platform = $2
AND (group_id IS NOT DISTINCT FROM $3)
AND (region IS NOT DISTINCT FROM $4)
AND until > $5
LIMIT 1`
var
dummy
int
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
q
,
ruleID
,
platform
,
opsNullInt64
(
groupID
),
opsNullString
(
region
),
now
)
.
Scan
(
&
dummy
)
if
err
!=
nil
{
if
err
==
sql
.
ErrNoRows
{
return
false
,
nil
}
return
false
,
err
}
return
true
,
nil
}
func
scanOpsAlertEvent
(
row
opsAlertEventRow
)
(
*
service
.
OpsAlertEvent
,
error
)
{
var
ev
service
.
OpsAlertEvent
var
metricValue
sql
.
NullFloat64
...
...
@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
args
=
append
(
args
,
severity
)
clauses
=
append
(
clauses
,
"severity = $"
+
itoa
(
len
(
args
)))
}
if
filter
.
EmailSent
!=
nil
{
args
=
append
(
args
,
*
filter
.
EmailSent
)
clauses
=
append
(
clauses
,
"email_sent = $"
+
itoa
(
len
(
args
)))
}
if
filter
.
StartTime
!=
nil
&&
!
filter
.
StartTime
.
IsZero
()
{
args
=
append
(
args
,
*
filter
.
StartTime
)
clauses
=
append
(
clauses
,
"fired_at >= $"
+
itoa
(
len
(
args
)))
...
...
@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
clauses
=
append
(
clauses
,
"fired_at < $"
+
itoa
(
len
(
args
)))
}
// Cursor pagination (descending by fired_at, then id)
if
filter
.
BeforeFiredAt
!=
nil
&&
!
filter
.
BeforeFiredAt
.
IsZero
()
&&
filter
.
BeforeID
!=
nil
&&
*
filter
.
BeforeID
>
0
{
args
=
append
(
args
,
*
filter
.
BeforeFiredAt
)
tsArg
:=
"$"
+
itoa
(
len
(
args
))
args
=
append
(
args
,
*
filter
.
BeforeID
)
idArg
:=
"$"
+
itoa
(
len
(
args
))
clauses
=
append
(
clauses
,
fmt
.
Sprintf
(
"(fired_at < %s OR (fired_at = %s AND id < %s))"
,
tsArg
,
tsArg
,
idArg
))
}
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
if
platform
:=
strings
.
TrimSpace
(
filter
.
Platform
);
platform
!=
""
{
args
=
append
(
args
,
platform
)
...
...
backend/internal/repository/ops_repo_metrics.go
View file @
6901b64f
...
...
@@ -296,9 +296,10 @@ INSERT INTO ops_job_heartbeats (
last_error_at,
last_error,
last_duration_ms,
last_result,
updated_at
) VALUES (
$1,$2,$3,$4,$5,$6,NOW()
$1,$2,$3,$4,$5,$6,
$7,
NOW()
)
ON CONFLICT (job_name) DO UPDATE SET
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
...
...
@@ -312,6 +313,10 @@ ON CONFLICT (job_name) DO UPDATE SET
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
END,
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
last_result = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
ELSE ops_job_heartbeats.last_result
END,
updated_at = NOW()`
_
,
err
:=
r
.
db
.
ExecContext
(
...
...
@@ -323,6 +328,7 @@ ON CONFLICT (job_name) DO UPDATE SET
opsNullTime
(
input
.
LastErrorAt
),
opsNullString
(
input
.
LastError
),
opsNullInt
(
input
.
LastDurationMs
),
opsNullString
(
input
.
LastResult
),
)
return
err
}
...
...
@@ -340,6 +346,7 @@ SELECT
last_error_at,
last_error,
last_duration_ms,
last_result,
updated_at
FROM ops_job_heartbeats
ORDER BY job_name ASC`
...
...
@@ -359,6 +366,8 @@ ORDER BY job_name ASC`
var
lastError
sql
.
NullString
var
lastDuration
sql
.
NullInt64
var
lastResult
sql
.
NullString
if
err
:=
rows
.
Scan
(
&
item
.
JobName
,
&
lastRun
,
...
...
@@ -366,6 +375,7 @@ ORDER BY job_name ASC`
&
lastErrorAt
,
&
lastError
,
&
lastDuration
,
&
lastResult
,
&
item
.
UpdatedAt
,
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -391,6 +401,10 @@ ORDER BY job_name ASC`
v
:=
lastDuration
.
Int64
item
.
LastDurationMs
=
&
v
}
if
lastResult
.
Valid
{
v
:=
lastResult
.
String
item
.
LastResult
=
&
v
}
out
=
append
(
out
,
&
item
)
}
...
...
backend/internal/repository/proxy_latency_cache.go
0 → 100644
View file @
6901b64f
package
repository
import
(
"context"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// proxyLatencyKeyPrefix is the Redis key namespace for per-proxy latency data.
const proxyLatencyKeyPrefix = "proxy:latency:"

// proxyLatencyKey builds the Redis key that stores latency info for one proxy,
// e.g. "proxy:latency:42".
func proxyLatencyKey(proxyID int64) string {
	return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
}
// proxyLatencyCache stores per-proxy latency probe results in Redis.
type proxyLatencyCache struct {
	rdb *redis.Client
}

// NewProxyLatencyCache returns a Redis-backed implementation of
// service.ProxyLatencyCache.
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
	return &proxyLatencyCache{rdb: rdb}
}
func
(
c
*
proxyLatencyCache
)
GetProxyLatencies
(
ctx
context
.
Context
,
proxyIDs
[]
int64
)
(
map
[
int64
]
*
service
.
ProxyLatencyInfo
,
error
)
{
results
:=
make
(
map
[
int64
]
*
service
.
ProxyLatencyInfo
)
if
len
(
proxyIDs
)
==
0
{
return
results
,
nil
}
keys
:=
make
([]
string
,
0
,
len
(
proxyIDs
))
for
_
,
id
:=
range
proxyIDs
{
keys
=
append
(
keys
,
proxyLatencyKey
(
id
))
}
values
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
results
,
err
}
for
i
,
raw
:=
range
values
{
if
raw
==
nil
{
continue
}
var
payload
[]
byte
switch
v
:=
raw
.
(
type
)
{
case
string
:
payload
=
[]
byte
(
v
)
case
[]
byte
:
payload
=
v
default
:
continue
}
var
info
service
.
ProxyLatencyInfo
if
err
:=
json
.
Unmarshal
(
payload
,
&
info
);
err
!=
nil
{
continue
}
results
[
proxyIDs
[
i
]]
=
&
info
}
return
results
,
nil
}
func
(
c
*
proxyLatencyCache
)
SetProxyLatency
(
ctx
context
.
Context
,
proxyID
int64
,
info
*
service
.
ProxyLatencyInfo
)
error
{
if
info
==
nil
{
return
nil
}
payload
,
err
:=
json
.
Marshal
(
info
)
if
err
!=
nil
{
return
err
}
return
c
.
rdb
.
Set
(
ctx
,
proxyLatencyKey
(
proxyID
),
payload
,
0
)
.
Err
()
}
backend/internal/repository/proxy_probe_service.go
View file @
6901b64f
...
...
@@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
}
const
defaultIPInfoURL
=
"https://ipinfo.io/json"
const
(
defaultIPInfoURL
=
"http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout
=
30
*
time
.
Second
)
type
proxyProbeService
struct
{
ipInfoURL
string
...
...
@@ -46,7 +50,7 @@ type proxyProbeService struct {
func
(
s
*
proxyProbeService
)
ProbeProxy
(
ctx
context
.
Context
,
proxyURL
string
)
(
*
service
.
ProxyExitInfo
,
int64
,
error
)
{
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
Timeout
:
15
*
time
.
Second
,
Timeout
:
defaultProxyProbeTimeout
,
InsecureSkipVerify
:
s
.
insecureSkipVerify
,
ProxyStrict
:
true
,
ValidateResolvedIP
:
s
.
validateResolvedIP
,
...
...
@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
var
ipInfo
struct
{
IP
string
`json:"ip"`
Status
string
`json:"status"`
Message
string
`json:"message"`
Query
string
`json:"query"`
City
string
`json:"city"`
Region
string
`json:"region"`
RegionName
string
`json:"regionName"`
Country
string
`json:"country"`
CountryCode
string
`json:"countryCode"`
}
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
...
...
@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if
err
:=
json
.
Unmarshal
(
body
,
&
ipInfo
);
err
!=
nil
{
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"failed to parse response: %w"
,
err
)
}
if
strings
.
ToLower
(
ipInfo
.
Status
)
!=
"success"
{
if
ipInfo
.
Message
==
""
{
ipInfo
.
Message
=
"ip-api request failed"
}
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"ip-api request failed: %s"
,
ipInfo
.
Message
)
}
region
:=
ipInfo
.
RegionName
if
region
==
""
{
region
=
ipInfo
.
Region
}
return
&
service
.
ProxyExitInfo
{
IP
:
ipInfo
.
IP
,
IP
:
ipInfo
.
Query
,
City
:
ipInfo
.
City
,
Region
:
ipInfo
.
R
egion
,
Region
:
r
egion
,
Country
:
ipInfo
.
Country
,
CountryCode
:
ipInfo
.
CountryCode
,
},
latencyMs
,
nil
}
backend/internal/repository/proxy_probe_service_test.go
View file @
6901b64f
...
...
@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
func
(
s
*
ProxyProbeServiceSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
prober
=
&
proxyProbeService
{
ipInfoURL
:
"http://ip
info
.test/json"
,
ipInfoURL
:
"http://ip
-api
.test/json
/?lang=zh-CN
"
,
allowPrivateHosts
:
true
,
}
}
...
...
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
s
.
setupProxyServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
seen
<-
r
.
RequestURI
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"
ip
":"1.2.3.4","city":"c","region":"r","country":"cc"}`
)
_
,
_
=
io
.
WriteString
(
w
,
`{"
status":"success","query
":"1.2.3.4","city":"c","region
Name
":"r","country":"cc"
,"countryCode":"CC"
}`
)
}))
info
,
latencyMs
,
err
:=
s
.
prober
.
ProbeProxy
(
s
.
ctx
,
s
.
proxySrv
.
URL
)
...
...
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require
.
Equal
(
s
.
T
(),
"c"
,
info
.
City
)
require
.
Equal
(
s
.
T
(),
"r"
,
info
.
Region
)
require
.
Equal
(
s
.
T
(),
"cc"
,
info
.
Country
)
require
.
Equal
(
s
.
T
(),
"CC"
,
info
.
CountryCode
)
// Verify proxy received the request
select
{
case
uri
:=
<-
seen
:
require
.
Contains
(
s
.
T
(),
uri
,
"ip
info
.test"
,
"expected request to go through proxy"
)
require
.
Contains
(
s
.
T
(),
uri
,
"ip
-api
.test"
,
"expected request to go through proxy"
)
default
:
require
.
Fail
(
s
.
T
(),
"expected proxy to receive request"
)
}
...
...
backend/internal/repository/proxy_repo.go
View file @
6901b64f
...
...
@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func
(
r
*
proxyRepository
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
var
count
int64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT COUNT(*) FROM accounts WHERE proxy_id = $1"
,
[]
any
{
proxyID
},
&
count
);
err
!=
nil
{
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
"SELECT COUNT(*) FROM accounts WHERE proxy_id = $1
AND deleted_at IS NULL
"
,
[]
any
{
proxyID
},
&
count
);
err
!=
nil
{
return
0
,
err
}
return
count
,
nil
}
func
(
r
*
proxyRepository
)
ListAccountSummariesByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
([]
service
.
ProxyAccountSummary
,
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT id, name, platform, type, notes
FROM accounts
WHERE proxy_id = $1 AND deleted_at IS NULL
ORDER BY id DESC
`
,
proxyID
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
out
:=
make
([]
service
.
ProxyAccountSummary
,
0
)
for
rows
.
Next
()
{
var
(
id
int64
name
string
platform
string
accType
string
notes
sql
.
NullString
)
if
err
:=
rows
.
Scan
(
&
id
,
&
name
,
&
platform
,
&
accType
,
&
notes
);
err
!=
nil
{
return
nil
,
err
}
var
notesPtr
*
string
if
notes
.
Valid
{
notesPtr
=
&
notes
.
String
}
out
=
append
(
out
,
service
.
ProxyAccountSummary
{
ID
:
id
,
Name
:
name
,
Platform
:
platform
,
Type
:
accType
,
Notes
:
notesPtr
,
})
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func
(
r
*
proxyRepository
)
GetAccountCountsForProxies
(
ctx
context
.
Context
)
(
counts
map
[
int64
]
int64
,
err
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
"SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id"
)
...
...
backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
View file @
6901b64f
backend/internal/repository/session_limit_cache.go
0 → 100644
View file @
6901b64f
package
repository
import
(
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// Session-limit cache constants.
//
// Design notes:
// Each account's active sessions are tracked in a Redis sorted set:
//   - Key:    session_limit:account:{accountID}
//   - Member: sessionUUID (extracted from metadata.user_id)
//   - Score:  Unix timestamp of the session's last activity
//
// Expired sessions are pruned via ZREMRANGEBYSCORE, so no per-member TTL
// bookkeeping is required.
const (
	// Session-limit key prefix.
	// Format: session_limit:account:{accountID}
	sessionLimitKeyPrefix = "session_limit:account:"

	// Window-cost cache key prefix.
	// Format: window_cost:account:{accountID}
	windowCostKeyPrefix = "window_cost:account:"

	// Window-cost cache TTL (30 seconds).
	windowCostCacheTTL = 30 * time.Second
)
var (
	// registerSessionScript records session activity for an account.
	// It reads the clock via the Redis TIME command so that multiple app
	// instances never disagree about "now".
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = maxSessions
	// ARGV[2] = idleTimeout (seconds)
	// ARGV[3] = sessionUUID
	// Returns: 1 = allowed, 0 = rejected
	registerSessionScript = redis.NewScript(`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]

-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout

-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)

-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
	-- 会话已存在,刷新时间戳
	redis.call('ZADD', key, now, sessionUUID)
	redis.call('EXPIRE', key, idleTimeout + 60)
	return 1
end

-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
	-- 未达上限,添加新会话
	redis.call('ZADD', key, now, sessionUUID)
	redis.call('EXPIRE', key, idleTimeout + 60)
	return 1
end

-- 达到上限,拒绝新会话
return 0
`)

	// refreshSessionScript bumps an existing session's last-activity timestamp.
	// Unknown sessions are left untouched.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	// ARGV[2] = sessionUUID
	refreshSessionScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])

-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
	redis.call('ZADD', key, now, sessionUUID)
	redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`)

	// getActiveSessionCountScript prunes expired sessions, then returns the
	// number of sessions still active within the idle timeout.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	getActiveSessionCountScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout

-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)

	// isSessionActiveScript reports whether one session is present and has
	// been active within the idle timeout.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	// ARGV[2] = sessionUUID
	isSessionActiveScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout

-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
	return 0
end

-- 检查是否过期
if tonumber(score) <= expireBefore then
	return 0
end
return 1
`)
)
// sessionLimitCache implements service.SessionLimitCache on top of Redis.
type sessionLimitCache struct {
	rdb                *redis.Client
	defaultIdleTimeout time.Duration // default idle timeout (used by GetActiveSessionCount)
}

// NewSessionLimitCache creates a session-limit cache.
// defaultIdleTimeoutMinutes is the default idle timeout in minutes, used by
// queries that take no explicit timeout; non-positive values fall back to 5.
func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
	if defaultIdleTimeoutMinutes <= 0 {
		defaultIdleTimeoutMinutes = 5 // default: 5 minutes
	}
	return &sessionLimitCache{
		rdb:                rdb,
		defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
	}
}
// sessionLimitKey builds the Redis key for an account's active-session set.
func sessionLimitKey(accountID int64) string {
	return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
}

// windowCostKey builds the Redis key for an account's cached window cost.
func windowCostKey(accountID int64) string {
	return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
}
// RegisterSession 注册会话活动
func
(
c
*
sessionLimitCache
)
RegisterSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
maxSessions
int
,
idleTimeout
time
.
Duration
)
(
bool
,
error
)
{
if
sessionUUID
==
""
||
maxSessions
<=
0
{
return
true
,
nil
// 无效参数,默认允许
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
result
,
err
:=
registerSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
maxSessions
,
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
true
,
err
// 失败开放:缓存错误时允许请求通过
}
return
result
==
1
,
nil
}
// RefreshSession 刷新会话时间戳
func
(
c
*
sessionLimitCache
)
RefreshSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
idleTimeout
time
.
Duration
)
error
{
if
sessionUUID
==
""
{
return
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
_
,
err
:=
refreshSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Result
()
return
err
}
// GetActiveSessionCount 获取活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
getActiveSessionCountScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
)
.
Int
()
if
err
!=
nil
{
return
0
,
err
}
return
result
,
nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
int
),
nil
}
results
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
// 使用 pipeline 批量执行
pipe
:=
c
.
rdb
.
Pipeline
()
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
cmds
:=
make
(
map
[
int64
]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
key
:=
sessionLimitKey
(
accountID
)
cmds
[
accountID
]
=
getActiveSessionCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
idleTimeoutSeconds
)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_
,
_
=
pipe
.
Exec
(
ctx
)
for
accountID
,
cmd
:=
range
cmds
{
if
result
,
err
:=
cmd
.
Int
();
err
==
nil
{
results
[
accountID
]
=
result
}
}
return
results
,
nil
}
// IsSessionActive 检查会话是否活跃
func
(
c
*
sessionLimitCache
)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
{
if
sessionUUID
==
""
{
return
false
,
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
isSessionActiveScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
false
,
err
}
return
result
==
1
,
nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func
(
c
*
sessionLimitCache
)
GetWindowCost
(
ctx
context
.
Context
,
accountID
int64
)
(
float64
,
bool
,
error
)
{
key
:=
windowCostKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Float64
()
if
err
==
redis
.
Nil
{
return
0
,
false
,
nil
// 缓存未命中
}
if
err
!=
nil
{
return
0
,
false
,
err
}
return
val
,
true
,
nil
}
// SetWindowCost 设置窗口费用缓存
func
(
c
*
sessionLimitCache
)
SetWindowCost
(
ctx
context
.
Context
,
accountID
int64
,
cost
float64
)
error
{
key
:=
windowCostKey
(
accountID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
cost
,
windowCostCacheTTL
)
.
Err
()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func
(
c
*
sessionLimitCache
)
GetWindowCostBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
float64
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
float64
),
nil
}
// 构建批量查询的 keys
keys
:=
make
([]
string
,
len
(
accountIDs
))
for
i
,
accountID
:=
range
accountIDs
{
keys
[
i
]
=
windowCostKey
(
accountID
)
}
// 使用 MGET 批量获取
vals
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
results
:=
make
(
map
[
int64
]
float64
,
len
(
accountIDs
))
for
i
,
val
:=
range
vals
{
if
val
==
nil
{
continue
// 缓存未命中
}
// 尝试解析为 float64
switch
v
:=
val
.
(
type
)
{
case
string
:
if
cost
,
err
:=
strconv
.
ParseFloat
(
v
,
64
);
err
==
nil
{
results
[
accountIDs
[
i
]]
=
cost
}
case
float64
:
results
[
accountIDs
[
i
]]
=
v
}
}
return
results
,
nil
}
backend/internal/repository/usage_log_repo.go
View file @
6901b64f
...
...
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier,
account_rate_multiplier,
billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type
usageLogRepository
struct
{
client
*
dbent
.
Client
...
...
@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
stream,
duration_ms,
...
...
@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
, $30
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
...
...
@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log
.
TotalCost
,
log
.
ActualCost
,
rateMultiplier
,
log
.
AccountRateMultiplier
,
log
.
BillingType
,
log
.
Stream
,
duration
,
...
...
@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats
func
(
r
*
usageLogRepository
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
DashboardStats
,
error
)
{
stats
:=
&
DashboardStats
{}
now
:=
time
.
Now
()
.
UTC
()
today
UTC
:=
truncate
To
D
ay
UTC
(
now
)
now
:=
time
zone
.
Now
()
today
Start
:=
timezone
.
To
d
ay
(
)
if
err
:=
r
.
fillDashboardEntityStats
(
ctx
,
stats
,
today
UTC
,
now
);
err
!=
nil
{
if
err
:=
r
.
fillDashboardEntityStats
(
ctx
,
stats
,
today
Start
,
now
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
r
.
fillDashboardUsageStatsAggregated
(
ctx
,
stats
,
today
UTC
,
now
);
err
!=
nil
{
if
err
:=
r
.
fillDashboardUsageStatsAggregated
(
ctx
,
stats
,
today
Start
,
now
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta
}
stats
:=
&
DashboardStats
{}
now
:=
time
.
Now
()
.
UTC
()
today
UTC
:=
truncate
To
D
ay
UTC
(
now
)
now
:=
time
zone
.
Now
()
today
Start
:=
timezone
.
To
d
ay
(
)
if
err
:=
r
.
fillDashboardEntityStats
(
ctx
,
stats
,
today
UTC
,
now
);
err
!=
nil
{
if
err
:=
r
.
fillDashboardEntityStats
(
ctx
,
stats
,
today
Start
,
now
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
r
.
fillDashboardUsageStatsFromUsageLogs
(
ctx
,
stats
,
startUTC
,
endUTC
,
today
UTC
,
now
);
err
!=
nil
{
if
err
:=
r
.
fillDashboardUsageStatsFromUsageLogs
(
ctx
,
stats
,
startUTC
,
endUTC
,
today
Start
,
now
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`
hourStart
:=
now
.
UTC
(
)
.
Truncate
(
time
.
Hour
)
hourStart
:=
now
.
In
(
timezone
.
Location
()
)
.
Truncate
(
time
.
Hour
)
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
hourlyActiveQuery
,
[]
any
{
hourStart
},
&
stats
.
HourlyActiveUsers
);
err
!=
nil
{
if
err
!=
sql
.
ErrNoRows
{
return
err
...
...
@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
...
...
@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&
stats
.
Requests
,
&
stats
.
Tokens
,
&
stats
.
Cost
,
&
stats
.
StandardCost
,
&
stats
.
UserCost
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
...
...
@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&
stats
.
Requests
,
&
stats
.
Tokens
,
&
stats
.
Cost
,
&
stats
.
StandardCost
,
&
stats
.
UserCost
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
return
result
,
nil
}
// GetUsageTrendWithFilters returns usage trend data with optional
user/api_key
filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
(
results
[]
TrendDataPoint
,
err
error
)
{
// GetUsageTrendWithFilters returns usage trend data with optional filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
...
...
@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND api_key_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
apiKeyID
)
}
if
accountID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND account_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
accountID
)
}
if
groupID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND group_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
groupID
)
}
if
model
!=
""
{
query
+=
fmt
.
Sprintf
(
" AND model = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
model
)
}
if
stream
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
query
+=
" GROUP BY date ORDER BY date ASC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return
results
,
nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
(
results
[]
ModelStat
,
err
error
)
{
query
:=
`
// GetModelStatsWithFilters returns model statistics with optional filters
func
(
r
*
usageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
(
results
[]
ModelStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
actualCostExpr
=
"COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query
:=
fmt
.
Sprintf
(
`
SELECT
model,
COUNT(*) as requests,
...
...
@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
`
,
actualCostExpr
)
args
:=
[]
any
{
startTime
,
endTime
}
if
userID
>
0
{
...
...
@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query
+=
fmt
.
Sprintf
(
" AND account_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
accountID
)
}
if
groupID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND group_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
groupID
)
}
if
stream
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND stream = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
*
stream
)
}
query
+=
" GROUP BY model ORDER BY total_tokens DESC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
...
...
@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`
,
buildWhere
(
conditions
))
stats
:=
&
UsageStats
{}
var
totalAccountCost
float64
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
...
...
@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&
stats
.
TotalCacheTokens
,
&
stats
.
TotalCost
,
&
stats
.
TotalActualCost
,
&
totalAccountCost
,
&
stats
.
AverageDurationMs
,
);
err
!=
nil
{
return
nil
,
err
}
if
filters
.
AccountID
>
0
{
stats
.
TotalAccountCost
=
&
totalAccountCost
}
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheTokens
return
stats
,
nil
}
...
...
@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
...
...
@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var
tokens
int64
var
cost
float64
var
actualCost
float64
if
err
=
rows
.
Scan
(
&
date
,
&
requests
,
&
tokens
,
&
cost
,
&
actualCost
);
err
!=
nil
{
var
userCost
float64
if
err
=
rows
.
Scan
(
&
date
,
&
requests
,
&
tokens
,
&
cost
,
&
actualCost
,
&
userCost
);
err
!=
nil
{
return
nil
,
err
}
t
,
_
:=
time
.
Parse
(
"2006-01-02"
,
date
)
...
...
@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens
:
tokens
,
Cost
:
cost
,
ActualCost
:
actualCost
,
UserCost
:
userCost
,
})
}
if
err
=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
var
totalAc
tual
Cost
,
totalStandardCost
float64
var
totalAc
countCost
,
totalUser
Cost
,
totalStandardCost
float64
var
totalRequests
,
totalTokens
int64
var
highestCostDay
,
highestRequestDay
*
AccountUsageHistory
for
i
:=
range
history
{
h
:=
&
history
[
i
]
totalActualCost
+=
h
.
ActualCost
totalAccountCost
+=
h
.
ActualCost
totalUserCost
+=
h
.
UserCost
totalStandardCost
+=
h
.
Cost
totalRequests
+=
h
.
Requests
totalTokens
+=
h
.
Tokens
...
...
@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary
:=
AccountUsageSummary
{
Days
:
daysCount
,
ActualDaysUsed
:
actualDaysUsed
,
TotalCost
:
totalActualCost
,
TotalCost
:
totalAccountCost
,
TotalUserCost
:
totalUserCost
,
TotalStandardCost
:
totalStandardCost
,
TotalRequests
:
totalRequests
,
TotalTokens
:
totalTokens
,
AvgDailyCost
:
totalActualCost
/
float64
(
actualDaysUsed
),
AvgDailyCost
:
totalAccountCost
/
float64
(
actualDaysUsed
),
AvgDailyUserCost
:
totalUserCost
/
float64
(
actualDaysUsed
),
AvgDailyRequests
:
float64
(
totalRequests
)
/
float64
(
actualDaysUsed
),
AvgDailyTokens
:
float64
(
totalTokens
)
/
float64
(
actualDaysUsed
),
AvgDurationMs
:
avgDuration
,
...
...
@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary
.
Today
=
&
struct
{
Date
string
`json:"date"`
Cost
float64
`json:"cost"`
UserCost
float64
`json:"user_cost"`
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
}{
Date
:
history
[
i
]
.
Date
,
Cost
:
history
[
i
]
.
ActualCost
,
UserCost
:
history
[
i
]
.
UserCost
,
Requests
:
history
[
i
]
.
Requests
,
Tokens
:
history
[
i
]
.
Tokens
,
}
...
...
@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date
string
`json:"date"`
Label
string
`json:"label"`
Cost
float64
`json:"cost"`
UserCost
float64
`json:"user_cost"`
Requests
int64
`json:"requests"`
}{
Date
:
highestCostDay
.
Date
,
Label
:
highestCostDay
.
Label
,
Cost
:
highestCostDay
.
ActualCost
,
UserCost
:
highestCostDay
.
UserCost
,
Requests
:
highestCostDay
.
Requests
,
}
}
...
...
@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label
string
`json:"label"`
Requests
int64
`json:"requests"`
Cost
float64
`json:"cost"`
UserCost
float64
`json:"user_cost"`
}{
Date
:
highestRequestDay
.
Date
,
Label
:
highestRequestDay
.
Label
,
Requests
:
highestRequestDay
.
Requests
,
Cost
:
highestRequestDay
.
ActualCost
,
UserCost
:
highestRequestDay
.
UserCost
,
}
}
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
)
models
,
err
:=
r
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
nil
)
if
err
!=
nil
{
models
=
[]
ModelStat
{}
}
...
...
@@ -2015,6 +2073,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
totalCost
float64
actualCost
float64
rateMultiplier
float64
accountRateMultiplier
sql
.
NullFloat64
billingType
int16
stream
bool
durationMs
sql
.
NullInt64
...
...
@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
totalCost
,
&
actualCost
,
&
rateMultiplier
,
&
accountRateMultiplier
,
&
billingType
,
&
stream
,
&
durationMs
,
...
...
@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost
:
totalCost
,
ActualCost
:
actualCost
,
RateMultiplier
:
rateMultiplier
,
AccountRateMultiplier
:
nullFloat64Ptr
(
accountRateMultiplier
),
BillingType
:
int8
(
billingType
),
Stream
:
stream
,
ImageCount
:
imageCount
,
...
...
@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
return
sql
.
NullInt64
{
Int64
:
int64
(
*
v
),
Valid
:
true
}
}
func
nullFloat64Ptr
(
v
sql
.
NullFloat64
)
*
float64
{
if
!
v
.
Valid
{
return
nil
}
out
:=
v
.
Float64
return
&
out
}
func
nullString
(
v
*
string
)
sql
.
NullString
{
if
v
==
nil
||
*
v
==
""
{
return
sql
.
NullString
{}
...
...
backend/internal/repository/usage_log_repo_integration_test.go
View file @
6901b64f
...
...
@@ -11,6 +11,7 @@ import (
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
...
...
@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite
.
Run
(
t
,
new
(
UsageLogRepoSuite
))
}
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
func
truncateToDayUTC
(
t
time
.
Time
)
time
.
Time
{
t
=
t
.
UTC
()
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
time
.
UTC
)
}
func
(
s
*
UsageLogRepoSuite
)
createUsageLog
(
user
*
service
.
User
,
apiKey
*
service
.
APIKey
,
account
*
service
.
Account
,
inputTokens
,
outputTokens
int
,
cost
float64
,
createdAt
time
.
Time
)
*
service
.
UsageLog
{
log
:=
&
service
.
UsageLog
{
UserID
:
user
.
ID
,
...
...
@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s
.
Require
()
.
Error
(
err
,
"expected error for non-existent ID"
)
}
// TestGetByID_ReturnsAccountRateMultiplier verifies that a usage log created
// with a non-nil AccountRateMultiplier round-trips through Create/GetByID
// with the multiplier preserved.
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
	user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
	apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
	account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
	m := 0.5
	// CreatedAt is anchored inside "today" (timezone.Today() + 2h) so the row
	// falls in the current day's partition/window.
	log := &service.UsageLog{
		UserID:                user.ID,
		APIKeyID:              apiKey.ID,
		AccountID:             account.ID,
		RequestID:             uuid.New().String(),
		Model:                 "claude-3",
		InputTokens:           10,
		OutputTokens:          20,
		TotalCost:             1.0,
		ActualCost:            2.0,
		AccountRateMultiplier: &m,
		CreatedAt:             timezone.Today().Add(2 * time.Hour),
	}
	_, err := s.repo.Create(s.ctx, log)
	s.Require().NoError(err)
	got, err := s.repo.GetByID(s.ctx, log.ID)
	s.Require().NoError(err)
	// The multiplier must survive the scan as a non-nil pointer with the
	// original value.
	s.Require().NotNil(got.AccountRateMultiplier)
	s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// --- Delete ---
func
(
s
*
UsageLogRepoSuite
)
TestDelete
()
{
...
...
@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey
:=
mustCreateApiKey
(
s
.
T
(),
s
.
client
,
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-acctoday"
,
Name
:
"k"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-today"
})
s
.
createUsageLog
(
user
,
apiKey
,
account
,
10
,
20
,
0.5
,
time
.
Now
())
createdAt
:=
timezone
.
Today
()
.
Add
(
1
*
time
.
Hour
)
m1
:=
1.5
m2
:=
0.0
_
,
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
&
service
.
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
uuid
.
New
()
.
String
(),
Model
:
"claude-3"
,
InputTokens
:
10
,
OutputTokens
:
20
,
TotalCost
:
1.0
,
ActualCost
:
2.0
,
AccountRateMultiplier
:
&
m1
,
CreatedAt
:
createdAt
,
})
s
.
Require
()
.
NoError
(
err
)
_
,
err
=
s
.
repo
.
Create
(
s
.
ctx
,
&
service
.
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
uuid
.
New
()
.
String
(),
Model
:
"claude-3"
,
InputTokens
:
5
,
OutputTokens
:
5
,
TotalCost
:
0.5
,
ActualCost
:
1.0
,
AccountRateMultiplier
:
&
m2
,
CreatedAt
:
createdAt
,
})
s
.
Require
()
.
NoError
(
err
)
stats
,
err
:=
s
.
repo
.
GetAccountTodayStats
(
s
.
ctx
,
account
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetAccountTodayStats"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
stats
.
Requests
)
s
.
Require
()
.
Equal
(
int64
(
30
),
stats
.
Tokens
)
s
.
Require
()
.
Equal
(
int64
(
2
),
stats
.
Requests
)
s
.
Require
()
.
Equal
(
int64
(
40
),
stats
.
Tokens
)
// account cost = SUM(total_cost * account_rate_multiplier)
s
.
Require
()
.
InEpsilon
(
1.5
,
stats
.
Cost
,
0.0001
)
// standard cost = SUM(total_cost)
s
.
Require
()
.
InEpsilon
(
1.5
,
stats
.
StandardCost
,
0.0001
)
// user cost = SUM(actual_cost)
s
.
Require
()
.
InEpsilon
(
3.0
,
stats
.
UserCost
,
0.0001
)
}
func
(
s
*
UsageLogRepoSuite
)
TestDashboardAggregationConsistency
()
{
...
...
@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime
:=
base
.
Add
(
48
*
time
.
Hour
)
// Test with user filter
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters user filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with apiKey filter
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
0
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
trend
,
2
)
// Test with both filters
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
)
trend
,
err
=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"day"
,
user
.
ID
,
apiKey
.
ID
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters both filters"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime
:=
base
.
Add
(
-
1
*
time
.
Hour
)
endTime
:=
base
.
Add
(
3
*
time
.
Hour
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
)
trend
,
err
:=
s
.
repo
.
GetUsageTrendWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
"hour"
,
user
.
ID
,
0
,
0
,
0
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetUsageTrendWithFilters hourly"
)
s
.
Require
()
.
Len
(
trend
,
2
)
}
...
...
@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime
:=
base
.
Add
(
2
*
time
.
Hour
)
// Test with user filter
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
)
stats
,
err
:=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
user
.
ID
,
0
,
0
,
0
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters user filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with apiKey filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
apiKey
.
ID
,
0
,
0
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters apiKey filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
// Test with account filter
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
)
stats
,
err
=
s
.
repo
.
GetModelStatsWithFilters
(
s
.
ctx
,
startTime
,
endTime
,
0
,
0
,
account
.
ID
,
0
,
nil
)
s
.
Require
()
.
NoError
(
err
,
"GetModelStatsWithFilters account filter"
)
s
.
Require
()
.
Len
(
stats
,
2
)
}
...
...
backend/internal/repository/wire.go
View file @
6901b64f
...
...
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return
NewPricingRemoteClient
(
cfg
.
Update
.
ProxyURL
)
}
// ProvideSessionLimitCache creates the session-limit cache used to cap the
// number of concurrent sessions for Anthropic OAuth/SetupToken accounts.
// The idle timeout comes from Gateway.SessionIdleTimeoutMinutes when set to
// a positive value; otherwise a built-in default is used.
func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
	defaultIdleTimeoutMinutes := 5 // default: 5-minute idle timeout
	if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
		defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
	}
	return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
}
// ProviderSet is the Wire provider set for all repositories
var
ProviderSet
=
wire
.
NewSet
(
NewUserRepository
,
...
...
@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache
,
NewTimeoutCounterCache
,
ProvideConcurrencyCache
,
ProvideSessionLimitCache
,
NewDashboardCache
,
NewEmailCache
,
NewIdentityCache
,
...
...
@@ -69,6 +80,7 @@ var ProviderSet = wire.NewSet(
NewGeminiTokenCache
,
NewSchedulerCache
,
NewSchedulerOutboxRepository
,
NewProxyLatencyCache
,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
...
...
backend/internal/server/api_contract_test.go
View file @
6901b64f
...
...
@@ -241,6 +241,7 @@ func TestAPIContracts(t *testing.T) {
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
...
...
@@ -435,12 +436,12 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
@@ -779,6 +780,10 @@ func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id
return
errors
.
New
(
"not implemented"
)
}
// SetModelRateLimit is a stub; the contract tests never exercise this path.
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
	return errors.New("not implemented")
}
// SetOverloaded is a stub; the contract tests never exercise this path.
func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
	return errors.New("not implemented")
}
...
...
@@ -799,6 +804,10 @@ func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id in
return
errors
.
New
(
"not implemented"
)
}
// ClearModelRateLimits is a stub; the contract tests never exercise this path.
func (s *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error {
	return errors.New("not implemented")
}
// UpdateSessionWindow is a stub; the contract tests never exercise this path.
func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
	return errors.New("not implemented")
}
...
...
@@ -858,6 +867,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64)
return
0
,
errors
.
New
(
"not implemented"
)
}
// ListAccountSummariesByProxyID is a stub; the contract tests never exercise
// this path.
func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
	return nil, errors.New("not implemented")
}
// stubRedeemCodeRepo is a no-op redeem-code repository used to satisfy the
// service constructor in contract tests.
type stubRedeemCodeRepo struct{}
func
(
stubRedeemCodeRepo
)
Create
(
ctx
context
.
Context
,
code
*
service
.
RedeemCode
)
error
{
...
...
@@ -1229,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/middleware/security_headers.go
View file @
6901b64f
package
middleware
import
(
"crypto/rand"
"encoding/base64"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
const (
	// CSPNonceKey is the gin context key under which the per-request CSP
	// nonce is stored.
	CSPNonceKey = "csp_nonce"
	// NonceTemplate is the placeholder in a CSP policy string that is
	// replaced with the per-request nonce.
	NonceTemplate = "__CSP_NONCE__"
	// CloudflareInsightsDomain is the script origin for Cloudflare Web
	// Analytics, whitelisted in script-src.
	CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
)
// GenerateNonce returns a fresh, cryptographically random nonce: 16 bytes
// of entropy encoded as standard base64 (24 characters of output).
func GenerateNonce() string {
	raw := make([]byte, 16)
	// crypto/rand.Read draws from the OS CSPRNG; the error is deliberately
	// discarded since it only fails when the platform entropy source is broken.
	_, _ = rand.Read(raw)
	return base64.StdEncoding.EncodeToString(raw)
}
// GetNonceFromContext retrieves the CSP nonce from gin context
func
GetNonceFromContext
(
c
*
gin
.
Context
)
string
{
if
nonce
,
exists
:=
c
.
Get
(
CSPNonceKey
);
exists
{
if
s
,
ok
:=
nonce
.
(
string
);
ok
{
return
s
}
}
return
""
}
// SecurityHeaders sets baseline security headers for all responses.
func
SecurityHeaders
(
cfg
config
.
CSPConfig
)
gin
.
HandlerFunc
{
policy
:=
strings
.
TrimSpace
(
cfg
.
Policy
)
...
...
@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy
=
config
.
DefaultCSPPolicy
}
// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
policy
=
enhanceCSPPolicy
(
policy
)
return
func
(
c
*
gin
.
Context
)
{
c
.
Header
(
"X-Content-Type-Options"
,
"nosniff"
)
c
.
Header
(
"X-Frame-Options"
,
"DENY"
)
c
.
Header
(
"Referrer-Policy"
,
"strict-origin-when-cross-origin"
)
if
cfg
.
Enabled
{
c
.
Header
(
"Content-Security-Policy"
,
policy
)
// Generate nonce for this request
nonce
:=
GenerateNonce
()
c
.
Set
(
CSPNonceKey
,
nonce
)
// Replace nonce placeholder in policy
finalPolicy
:=
strings
.
ReplaceAll
(
policy
,
NonceTemplate
,
"'nonce-"
+
nonce
+
"'"
)
c
.
Header
(
"Content-Security-Policy"
,
finalPolicy
)
}
c
.
Next
()
}
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func
enhanceCSPPolicy
(
policy
string
)
string
{
// Add nonce placeholder to script-src if not present
if
!
strings
.
Contains
(
policy
,
NonceTemplate
)
&&
!
strings
.
Contains
(
policy
,
"'nonce-"
)
{
policy
=
addToDirective
(
policy
,
"script-src"
,
NonceTemplate
)
}
// Add Cloudflare Insights domain to script-src if not present
if
!
strings
.
Contains
(
policy
,
CloudflareInsightsDomain
)
{
policy
=
addToDirective
(
policy
,
"script-src"
,
CloudflareInsightsDomain
)
}
return
policy
}
// addToDirective appends a source value to one CSP directive of a policy
// string. When the directive is missing, a new "<directive> 'self' <value>"
// entry is created, placed immediately after the default-src directive when
// one exists, otherwise prepended to the policy.
func addToDirective(policy, directive, value string) string {
	needle := directive + " "
	pos := strings.Index(policy, needle)
	if pos >= 0 {
		// Directive exists: splice the value just before its terminating
		// semicolon, or append when the directive ends the string.
		rel := strings.Index(policy[pos:], ";")
		if rel < 0 {
			return policy + " " + value
		}
		cut := pos + rel
		return policy[:cut] + " " + value + policy[cut:]
	}
	// Directive absent: insert a fresh one right after default-src when that
	// directive is present and properly terminated.
	if dsIdx := strings.Index(policy, "default-src "); dsIdx >= 0 {
		if rel := strings.Index(policy[dsIdx:], ";"); rel >= 0 {
			at := dsIdx + rel + 1
			return policy[:at] + " " + directive + " 'self' " + value + ";" + policy[at:]
		}
	}
	// Fallback: prepend the new directive.
	return directive + " 'self' " + value + "; " + policy
}
backend/internal/server/middleware/security_headers_test.go
0 → 100644
View file @
6901b64f
package
middleware
import
(
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Run gin in test mode so middleware tests are quiet and deterministic.
func init() {
	gin.SetMode(gin.TestMode)
}
// TestGenerateNonce checks that nonces are valid base64, carry 16 bytes of
// entropy, and are unique across calls.
func TestGenerateNonce(t *testing.T) {
	t.Run("generates_valid_base64_string", func(t *testing.T) {
		nonce := GenerateNonce()
		// Should be valid base64
		decoded, err := base64.StdEncoding.DecodeString(nonce)
		require.NoError(t, err)
		// Should decode to 16 bytes
		assert.Len(t, decoded, 16)
	})
	t.Run("generates_unique_nonces", func(t *testing.T) {
		nonces := make(map[string]bool)
		for i := 0; i < 100; i++ {
			nonce := GenerateNonce()
			assert.False(t, nonces[nonce], "nonce should be unique")
			nonces[nonce] = true
		}
	})
	t.Run("nonce_has_expected_length", func(t *testing.T) {
		nonce := GenerateNonce()
		// 16 bytes -> 24 chars in base64 (with padding)
		assert.Len(t, nonce, 24)
	})
}
// TestGetNonceFromContext checks retrieval of the nonce from the gin
// context, including the absent and wrong-type fallbacks.
func TestGetNonceFromContext(t *testing.T) {
	t.Run("returns_nonce_when_present", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		expectedNonce := "test-nonce-123"
		c.Set(CSPNonceKey, expectedNonce)
		nonce := GetNonceFromContext(c)
		assert.Equal(t, expectedNonce, nonce)
	})
	t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		nonce := GetNonceFromContext(c)
		assert.Empty(t, nonce)
	})
	t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		// Set a non-string value
		c.Set(CSPNonceKey, 12345)
		// Should return empty string for wrong type (safe type assertion)
		nonce := GetNonceFromContext(c)
		assert.Empty(t, nonce)
	})
}
// TestSecurityHeaders exercises the SecurityHeaders middleware: baseline
// headers are always set, the CSP header appears only when enabled, nonce
// placeholders are substituted per request, and the default policy is used
// when the configured one is blank.
func TestSecurityHeaders(t *testing.T) {
	t.Run("sets_basic_security_headers", func(t *testing.T) {
		cfg := config.CSPConfig{Enabled: false}
		middleware := SecurityHeaders(cfg)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		middleware(c)
		assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
		assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
		assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
	})
	t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
		cfg := config.CSPConfig{Enabled: false}
		middleware := SecurityHeaders(cfg)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		middleware(c)
		assert.Empty(t, w.Header().Get("Content-Security-Policy"))
	})
	t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
		cfg := config.CSPConfig{
			Enabled: true,
			Policy:  "default-src 'self'",
		}
		middleware := SecurityHeaders(cfg)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		middleware(c)
		csp := w.Header().Get("Content-Security-Policy")
		assert.NotEmpty(t, csp)
		// Policy is auto-enhanced with nonce and Cloudflare Insights domain
		assert.Contains(t, csp, "default-src 'self'")
		assert.Contains(t, csp, "'nonce-")
		assert.Contains(t, csp, CloudflareInsightsDomain)
	})
	t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
		cfg := config.CSPConfig{
			Enabled: true,
			Policy:  "script-src 'self' __CSP_NONCE__",
		}
		middleware := SecurityHeaders(cfg)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		middleware(c)
		csp := w.Header().Get("Content-Security-Policy")
		assert.NotEmpty(t, csp)
		assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
		assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
		// Verify nonce is stored in context
		nonce := GetNonceFromContext(c)
		assert.NotEmpty(t, nonce)
		assert.Contains(t, csp, "'nonce-"+nonce+"'")
	})
	t.Run("uses_default_policy_when_empty", func(t *testing.T) {
		cfg := config.CSPConfig{
			Enabled: true,
			Policy:  "",
		}
		middleware := SecurityHeaders(cfg)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		middleware(c)
		csp := w.Header().Get("Content-Security-Policy")
		assert.NotEmpty(t, csp)
		// Default policy should contain these elements
		assert.Contains(t, csp, "default-src 'self'")
	})
	t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
		cfg := config.CSPConfig{
			Enabled: true,
			Policy:  "  \t\n  ",
		}
		middleware := SecurityHeaders(cfg)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		middleware(c)
		csp := w.Header().Get("Content-Security-Policy")
		assert.NotEmpty(t, csp)
		assert.Contains(t, csp, "default-src 'self'")
	})
	t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
		cfg := config.CSPConfig{
			Enabled: true,
			Policy:  "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
		}
		middleware := SecurityHeaders(cfg)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		middleware(c)
		csp := w.Header().Get("Content-Security-Policy")
		nonce := GetNonceFromContext(c)
		// Count occurrences of the nonce
		count := strings.Count(csp, "'nonce-"+nonce+"'")
		assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
	})
	t.Run("calls_next_handler", func(t *testing.T) {
		cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
		middleware := SecurityHeaders(cfg)
		nextCalled := false
		router := gin.New()
		router.Use(middleware)
		router.GET("/test", func(c *gin.Context) {
			nextCalled = true
			c.Status(http.StatusOK)
		})
		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		router.ServeHTTP(w, req)
		assert.True(t, nextCalled, "next handler should be called")
		assert.Equal(t, http.StatusOK, w.Code)
	})
	t.Run("nonce_unique_per_request", func(t *testing.T) {
		cfg := config.CSPConfig{
			Enabled: true,
			Policy:  "script-src __CSP_NONCE__",
		}
		middleware := SecurityHeaders(cfg)
		nonces := make(map[string]bool)
		for i := 0; i < 10; i++ {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
			middleware(c)
			nonce := GetNonceFromContext(c)
			assert.False(t, nonces[nonce], "nonce should be unique per request")
			nonces[nonce] = true
		}
	})
}
// TestCSPNonceKey pins the context-key constant so a rename cannot silently
// break consumers that read it by value.
func TestCSPNonceKey(t *testing.T) {
	t.Run("constant_value", func(t *testing.T) {
		assert.Equal(t, "csp_nonce", CSPNonceKey)
	})
}
// TestNonceTemplate pins the placeholder constant used in CSP policies.
func TestNonceTemplate(t *testing.T) {
	t.Run("constant_value", func(t *testing.T) {
		assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
	})
}
// TestEnhanceCSPPolicy checks that policy enhancement injects the nonce
// placeholder and Cloudflare Insights domain exactly once, creates a
// script-src directive when missing, and leaves existing nonces alone.
func TestEnhanceCSPPolicy(t *testing.T) {
	t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
		policy := "default-src 'self'; script-src 'self'"
		enhanced := enhanceCSPPolicy(policy)
		assert.Contains(t, enhanced, NonceTemplate)
		assert.Contains(t, enhanced, CloudflareInsightsDomain)
	})
	t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
		policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
		enhanced := enhanceCSPPolicy(policy)
		// Should not duplicate
		count := strings.Count(enhanced, NonceTemplate)
		assert.Equal(t, 1, count)
	})
	t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
		policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
		enhanced := enhanceCSPPolicy(policy)
		count := strings.Count(enhanced, CloudflareInsightsDomain)
		assert.Equal(t, 1, count)
	})
	t.Run("handles_policy_without_script_src", func(t *testing.T) {
		policy := "default-src 'self'"
		enhanced := enhanceCSPPolicy(policy)
		assert.Contains(t, enhanced, "script-src")
		assert.Contains(t, enhanced, NonceTemplate)
		assert.Contains(t, enhanced, CloudflareInsightsDomain)
	})
	t.Run("preserves_existing_nonce", func(t *testing.T) {
		policy := "script-src 'self' 'nonce-existing'"
		enhanced := enhanceCSPPolicy(policy)
		// Should not add placeholder if nonce already exists
		assert.NotContains(t, enhanced, NonceTemplate)
		assert.Contains(t, enhanced, "'nonce-existing'")
	})
}
func
TestAddToDirective
(
t
*
testing
.
T
)
{
t
.
Run
(
"adds_to_existing_directive"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"script-src 'self'; style-src 'self'"
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"script-src 'self' https://example.com"
)
})
t
.
Run
(
"creates_directive_if_not_exists"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'"
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"script-src"
)
assert
.
Contains
(
t
,
result
,
"https://example.com"
)
})
t
.
Run
(
"handles_directive_at_end_without_semicolon"
,
func
(
t
*
testing
.
T
)
{
policy
:=
"default-src 'self'; script-src 'self'"
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"https://example.com"
)
})
t
.
Run
(
"handles_empty_policy"
,
func
(
t
*
testing
.
T
)
{
policy
:=
""
result
:=
addToDirective
(
policy
,
"script-src"
,
"https://example.com"
)
assert
.
Contains
(
t
,
result
,
"script-src"
)
assert
.
Contains
(
t
,
result
,
"https://example.com"
)
})
}
// Benchmark tests
func
BenchmarkGenerateNonce
(
b
*
testing
.
B
)
{
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
GenerateNonce
()
}
}
func
BenchmarkSecurityHeadersMiddleware
(
b
*
testing
.
B
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"script-src 'self' __CSP_NONCE__"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
b
.
ResetTimer
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
middleware
(
c
)
}
}
Prev
1
2
3
4
5
6
7
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment