Commit 6901b64f authored by cyhhao's avatar cyhhao
Browse files

merge: sync upstream changes

parents 32c47b15 dae0d532
......@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
......@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
startUTC := start.UTC()
endUTC := end.UTC()
if !endUTC.After(startUTC) {
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startUTC.Truncate(time.Hour)
hourEnd := endUTC.Truncate(time.Hour)
if endUTC.After(hourEnd) {
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDayUTC(startUTC)
dayEnd := truncateToDayUTC(endUTC)
if endUTC.After(dayEnd) {
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
......@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C
}
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH hourly AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
......@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH daily AS (
SELECT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
......@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
......@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err
}
......@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co
return err
}
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
func truncateToDay(t time.Time) time.Time {
return timezone.StartOfDay(t)
}
func truncateToMonthUTC(t time.Time) time.Time {
......
......@@ -11,8 +11,8 @@ import (
)
const (
geminiTokenKeyPrefix = "gemini:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
oauthTokenKeyPrefix = "oauth:token:"
oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
)
type geminiTokenCache struct {
......@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
}
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result()
}
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}
......@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
}
created, err := builder.Save(ctx)
if err == nil {
......@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
......@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearFallbackGroupID()
}
// 处理 ModelRouting:nil 时清除,否则设置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
} else {
builder = builder.ClearModelRouting()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
......
......@@ -55,7 +55,6 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
duration_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
......@@ -65,7 +64,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
) RETURNING id`
var id int64
......@@ -98,7 +97,6 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON),
opsNullInt(input.DurationMs),
opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated,
......@@ -135,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
}
where, args := buildOpsErrorLogsWhere(filter)
countSQL := "SELECT COUNT(*) FROM ops_error_logs " + where
countSQL := "SELECT COUNT(*) FROM ops_error_logs e " + where
var total int
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
......@@ -146,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
argsWithLimit := append(args, pageSize, offset)
selectSQL := `
SELECT
id,
created_at,
error_phase,
error_type,
severity,
COALESCE(upstream_status_code, status_code, 0),
COALESCE(platform, ''),
COALESCE(model, ''),
duration_ms,
COALESCE(client_request_id, ''),
COALESCE(request_id, ''),
COALESCE(error_message, ''),
user_id,
api_key_id,
account_id,
group_id,
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
COALESCE(request_path, ''),
stream
FROM ops_error_logs
e.id,
e.created_at,
e.error_phase,
e.error_type,
COALESCE(e.error_owner, ''),
COALESCE(e.error_source, ''),
e.severity,
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
COALESCE(e.is_retryable, false),
COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
COALESCE(u2.email, ''),
e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
e.user_id,
COALESCE(u.email, ''),
e.api_key_id,
e.account_id,
COALESCE(a.name, ''),
e.group_id,
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream
FROM ops_error_logs e
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN users u2 ON e.resolved_by_user_id = u2.id
` + where + `
ORDER BY created_at DESC
ORDER BY e.created_at DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
......@@ -179,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
out := make([]*service.OpsErrorLog, 0, pageSize)
for rows.Next() {
var item service.OpsErrorLog
var latency sql.NullInt64
var statusCode sql.NullInt64
var clientIP sql.NullString
var userID sql.NullInt64
var apiKeyID sql.NullInt64
var accountID sql.NullInt64
var accountName string
var groupID sql.NullInt64
var groupName string
var userEmail string
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
var resolvedByName string
var resolvedRetryID sql.NullInt64
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&item.Phase,
&item.Type,
&item.Owner,
&item.Source,
&item.Severity,
&statusCode,
&item.Platform,
&item.Model,
&latency,
&item.IsRetryable,
&item.RetryCount,
&item.Resolved,
&resolvedAt,
&resolvedBy,
&resolvedByName,
&resolvedRetryID,
&item.ClientRequestID,
&item.RequestID,
&item.Message,
&userID,
&userEmail,
&apiKeyID,
&accountID,
&accountName,
&groupID,
&groupName,
&clientIP,
&item.RequestPath,
&item.Stream,
); err != nil {
return nil, err
}
if latency.Valid {
v := int(latency.Int64)
item.LatencyMs = &v
if resolvedAt.Valid {
t := resolvedAt.Time
item.ResolvedAt = &t
}
if resolvedBy.Valid {
v := resolvedBy.Int64
item.ResolvedByUserID = &v
}
item.ResolvedByUserName = resolvedByName
if resolvedRetryID.Valid {
v := resolvedRetryID.Int64
item.ResolvedRetryID = &v
}
item.StatusCode = int(statusCode.Int64)
if clientIP.Valid {
......@@ -222,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := userID.Int64
item.UserID = &v
}
item.UserEmail = userEmail
if apiKeyID.Valid {
v := apiKeyID.Int64
item.APIKeyID = &v
......@@ -230,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := accountID.Int64
item.AccountID = &v
}
item.AccountName = accountName
if groupID.Valid {
v := groupID.Int64
item.GroupID = &v
}
item.GroupName = groupName
out = append(out, &item)
}
if err := rows.Err(); err != nil {
......@@ -258,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service
q := `
SELECT
id,
created_at,
error_phase,
error_type,
severity,
COALESCE(upstream_status_code, status_code, 0),
COALESCE(platform, ''),
COALESCE(model, ''),
duration_ms,
COALESCE(client_request_id, ''),
COALESCE(request_id, ''),
COALESCE(error_message, ''),
COALESCE(error_body, ''),
upstream_status_code,
COALESCE(upstream_error_message, ''),
COALESCE(upstream_error_detail, ''),
COALESCE(upstream_errors::text, ''),
is_business_limited,
user_id,
api_key_id,
account_id,
group_id,
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
COALESCE(request_path, ''),
stream,
COALESCE(user_agent, ''),
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
COALESCE(request_body::text, ''),
request_body_truncated,
request_body_bytes,
COALESCE(request_headers::text, '')
FROM ops_error_logs
WHERE id = $1
e.id,
e.created_at,
e.error_phase,
e.error_type,
COALESCE(e.error_owner, ''),
COALESCE(e.error_source, ''),
e.severity,
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
COALESCE(e.is_retryable, false),
COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
COALESCE(e.error_body, ''),
e.upstream_status_code,
COALESCE(e.upstream_error_message, ''),
COALESCE(e.upstream_error_detail, ''),
COALESCE(e.upstream_errors::text, ''),
e.is_business_limited,
e.user_id,
COALESCE(u.email, ''),
e.api_key_id,
e.account_id,
COALESCE(a.name, ''),
e.group_id,
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream,
COALESCE(e.user_agent, ''),
e.auth_latency_ms,
e.routing_latency_ms,
e.upstream_latency_ms,
e.response_latency_ms,
e.time_to_first_token_ms,
COALESCE(e.request_body::text, ''),
e.request_body_truncated,
e.request_body_bytes,
COALESCE(e.request_headers::text, '')
FROM ops_error_logs e
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
WHERE e.id = $1
LIMIT 1`
var out service.OpsErrorLogDetail
var latency sql.NullInt64
var statusCode sql.NullInt64
var upstreamStatusCode sql.NullInt64
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
var resolvedRetryID sql.NullInt64
var clientIP sql.NullString
var userID sql.NullInt64
var apiKeyID sql.NullInt64
......@@ -318,11 +375,18 @@ LIMIT 1`
&out.CreatedAt,
&out.Phase,
&out.Type,
&out.Owner,
&out.Source,
&out.Severity,
&statusCode,
&out.Platform,
&out.Model,
&latency,
&out.IsRetryable,
&out.RetryCount,
&out.Resolved,
&resolvedAt,
&resolvedBy,
&resolvedRetryID,
&out.ClientRequestID,
&out.RequestID,
&out.Message,
......@@ -333,9 +397,12 @@ LIMIT 1`
&out.UpstreamErrors,
&out.IsBusinessLimited,
&userID,
&out.UserEmail,
&apiKeyID,
&accountID,
&out.AccountName,
&groupID,
&out.GroupName,
&clientIP,
&out.RequestPath,
&out.Stream,
......@@ -355,9 +422,17 @@ LIMIT 1`
}
out.StatusCode = int(statusCode.Int64)
if latency.Valid {
v := int(latency.Int64)
out.LatencyMs = &v
if resolvedAt.Valid {
t := resolvedAt.Time
out.ResolvedAt = &t
}
if resolvedBy.Valid {
v := resolvedBy.Int64
out.ResolvedByUserID = &v
}
if resolvedRetryID.Valid {
v := resolvedRetryID.Int64
out.ResolvedRetryID = &v
}
if clientIP.Valid {
s := clientIP.String
......@@ -487,9 +562,15 @@ SET
status = $2,
finished_at = $3,
duration_ms = $4,
result_request_id = $5,
result_error_id = $6,
error_message = $7
success = $5,
http_status_code = $6,
upstream_request_id = $7,
used_account_id = $8,
response_preview = $9,
response_truncated = $10,
result_request_id = $11,
result_error_id = $12,
error_message = $13
WHERE id = $1`
_, err := r.db.ExecContext(
......@@ -499,8 +580,14 @@ WHERE id = $1`
strings.TrimSpace(input.Status),
nullTime(input.FinishedAt),
input.DurationMs,
nullBool(input.Success),
nullInt(input.HTTPStatusCode),
opsNullString(input.UpstreamRequestID),
nullInt64(input.UsedAccountID),
opsNullString(input.ResponsePreview),
nullBool(input.ResponseTruncated),
opsNullString(input.ResultRequestID),
opsNullInt64(input.ResultErrorID),
nullInt64(input.ResultErrorID),
opsNullString(input.ErrorMessage),
)
return err
......@@ -526,6 +613,12 @@ SELECT
started_at,
finished_at,
duration_ms,
success,
http_status_code,
upstream_request_id,
used_account_id,
response_preview,
response_truncated,
result_request_id,
result_error_id,
error_message
......@@ -540,6 +633,12 @@ LIMIT 1`
var startedAt sql.NullTime
var finishedAt sql.NullTime
var durationMs sql.NullInt64
var success sql.NullBool
var httpStatusCode sql.NullInt64
var upstreamRequestID sql.NullString
var usedAccountID sql.NullInt64
var responsePreview sql.NullString
var responseTruncated sql.NullBool
var resultRequestID sql.NullString
var resultErrorID sql.NullInt64
var errorMessage sql.NullString
......@@ -555,6 +654,12 @@ LIMIT 1`
&startedAt,
&finishedAt,
&durationMs,
&success,
&httpStatusCode,
&upstreamRequestID,
&usedAccountID,
&responsePreview,
&responseTruncated,
&resultRequestID,
&resultErrorID,
&errorMessage,
......@@ -579,6 +684,30 @@ LIMIT 1`
v := durationMs.Int64
out.DurationMs = &v
}
if success.Valid {
v := success.Bool
out.Success = &v
}
if httpStatusCode.Valid {
v := int(httpStatusCode.Int64)
out.HTTPStatusCode = &v
}
if upstreamRequestID.Valid {
s := upstreamRequestID.String
out.UpstreamRequestID = &s
}
if usedAccountID.Valid {
v := usedAccountID.Int64
out.UsedAccountID = &v
}
if responsePreview.Valid {
s := responsePreview.String
out.ResponsePreview = &s
}
if responseTruncated.Valid {
v := responseTruncated.Bool
out.ResponseTruncated = &v
}
if resultRequestID.Valid {
s := resultRequestID.String
out.ResultRequestID = &s
......@@ -602,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime {
return sql.NullTime{Time: t, Valid: true}
}
func nullBool(v *bool) sql.NullBool {
if v == nil {
return sql.NullBool{}
}
return sql.NullBool{Bool: *v, Valid: true}
}
func (r *opsRepository) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*service.OpsRetryAttempt, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if sourceErrorID <= 0 {
return nil, fmt.Errorf("invalid source_error_id")
}
if limit <= 0 {
limit = 50
}
if limit > 200 {
limit = 200
}
q := `
SELECT
r.id,
r.created_at,
COALESCE(r.requested_by_user_id, 0),
r.source_error_id,
COALESCE(r.mode, ''),
r.pinned_account_id,
COALESCE(pa.name, ''),
COALESCE(r.status, ''),
r.started_at,
r.finished_at,
r.duration_ms,
r.success,
r.http_status_code,
r.upstream_request_id,
r.used_account_id,
COALESCE(ua.name, ''),
r.response_preview,
r.response_truncated,
r.result_request_id,
r.result_error_id,
r.error_message
FROM ops_retry_attempts r
LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
LEFT JOIN accounts ua ON r.used_account_id = ua.id
WHERE r.source_error_id = $1
ORDER BY r.created_at DESC
LIMIT $2`
rows, err := r.db.QueryContext(ctx, q, sourceErrorID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]*service.OpsRetryAttempt, 0, 16)
for rows.Next() {
var item service.OpsRetryAttempt
var pinnedAccountID sql.NullInt64
var pinnedAccountName string
var requestedBy sql.NullInt64
var startedAt sql.NullTime
var finishedAt sql.NullTime
var durationMs sql.NullInt64
var success sql.NullBool
var httpStatusCode sql.NullInt64
var upstreamRequestID sql.NullString
var usedAccountID sql.NullInt64
var usedAccountName string
var responsePreview sql.NullString
var responseTruncated sql.NullBool
var resultRequestID sql.NullString
var resultErrorID sql.NullInt64
var errorMessage sql.NullString
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&requestedBy,
&item.SourceErrorID,
&item.Mode,
&pinnedAccountID,
&pinnedAccountName,
&item.Status,
&startedAt,
&finishedAt,
&durationMs,
&success,
&httpStatusCode,
&upstreamRequestID,
&usedAccountID,
&usedAccountName,
&responsePreview,
&responseTruncated,
&resultRequestID,
&resultErrorID,
&errorMessage,
); err != nil {
return nil, err
}
item.RequestedByUserID = requestedBy.Int64
if pinnedAccountID.Valid {
v := pinnedAccountID.Int64
item.PinnedAccountID = &v
}
item.PinnedAccountName = pinnedAccountName
if startedAt.Valid {
t := startedAt.Time
item.StartedAt = &t
}
if finishedAt.Valid {
t := finishedAt.Time
item.FinishedAt = &t
}
if durationMs.Valid {
v := durationMs.Int64
item.DurationMs = &v
}
if success.Valid {
v := success.Bool
item.Success = &v
}
if httpStatusCode.Valid {
v := int(httpStatusCode.Int64)
item.HTTPStatusCode = &v
}
if upstreamRequestID.Valid {
item.UpstreamRequestID = &upstreamRequestID.String
}
if usedAccountID.Valid {
v := usedAccountID.Int64
item.UsedAccountID = &v
}
item.UsedAccountName = usedAccountName
if responsePreview.Valid {
item.ResponsePreview = &responsePreview.String
}
if responseTruncated.Valid {
v := responseTruncated.Bool
item.ResponseTruncated = &v
}
if resultRequestID.Valid {
item.ResultRequestID = &resultRequestID.String
}
if resultErrorID.Valid {
v := resultErrorID.Int64
item.ResultErrorID = &v
}
if errorMessage.Valid {
item.ErrorMessage = &errorMessage.String
}
out = append(out, &item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if errorID <= 0 {
return fmt.Errorf("invalid error id")
}
q := `
UPDATE ops_error_logs
SET
resolved = $2,
resolved_at = $3,
resolved_by_user_id = $4,
resolved_retry_id = $5
WHERE id = $1`
at := sql.NullTime{}
if resolvedAt != nil && !resolvedAt.IsZero() {
at = sql.NullTime{Time: resolvedAt.UTC(), Valid: true}
} else if resolved {
now := time.Now().UTC()
at = sql.NullTime{Time: now, Valid: true}
}
_, err := r.db.ExecContext(
ctx,
q,
errorID,
resolved,
at,
nullInt64(resolvedByUserID),
nullInt64(resolvedRetryID),
)
return err
}
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses := make([]string, 0, 8)
args := make([]any, 0, 8)
clauses := make([]string, 0, 12)
args := make([]any, 0, 12)
clauses = append(clauses, "1=1")
phaseFilter := ""
if filter != nil {
phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
}
// ops_error_logs primarily stores client-visible error requests (status>=400),
// ops_error_logs stores client-visible error requests (status>=400),
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
// By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
// If Resolved is not specified, do not filter by resolved state (backward-compatible).
resolvedFilter := (*bool)(nil)
if filter != nil {
resolvedFilter = filter.Resolved
}
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, filter.StartTime.UTC())
clauses = append(clauses, "created_at >= $"+itoa(len(args)))
clauses = append(clauses, "e.created_at >= $"+itoa(len(args)))
}
if filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, filter.EndTime.UTC())
// Keep time-window semantics consistent with other ops queries: [start, end)
clauses = append(clauses, "created_at < $"+itoa(len(args)))
clauses = append(clauses, "e.created_at < $"+itoa(len(args)))
}
if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p)
......@@ -643,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
}
if filter != nil {
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
args = append(args, owner)
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
}
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
args = append(args, source)
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
}
}
if resolvedFilter != nil {
args = append(args, *resolvedFilter)
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
view := ""
if filter != nil {
view = strings.ToLower(strings.TrimSpace(filter.View))
}
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
case "excluded":
clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
}
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
} else if filter.StatusCodesOther {
// "Other" means: status codes not in the common list.
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
args = append(args, pq.Array(known))
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
}
// Exact correlation keys (preferred for request↔upstream linkage).
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
args = append(args, rid)
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
}
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
args = append(args, crid)
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
args = append(args, like)
......@@ -654,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
}
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
like := "%" + userQuery + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "u.email ILIKE $"+n)
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
......
......@@ -354,7 +354,7 @@ SELECT
created_at
FROM ops_alert_events
` + where + `
ORDER BY fired_at DESC
ORDER BY fired_at DESC, id DESC
LIMIT ` + limitArg
rows, err := r.db.QueryContext(ctx, q, args...)
......@@ -413,6 +413,43 @@ LIMIT ` + limitArg
return out, nil
}
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return nil, fmt.Errorf("invalid event id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE id = $1`
row := r.db.QueryRowContext(ctx, q, eventID)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
......@@ -591,6 +628,121 @@ type opsAlertEventRow interface {
Scan(dest ...any) error
}
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
if input.RuleID <= 0 {
return nil, fmt.Errorf("invalid rule_id")
}
platform := strings.TrimSpace(input.Platform)
if platform == "" {
return nil, fmt.Errorf("invalid platform")
}
if input.Until.IsZero() {
return nil, fmt.Errorf("invalid until")
}
q := `
INSERT INTO ops_alert_silences (
rule_id,
platform,
group_id,
region,
until,
reason,
created_by,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,NOW()
)
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
row := r.db.QueryRowContext(
ctx,
q,
input.RuleID,
platform,
opsNullInt64(input.GroupID),
opsNullString(input.Region),
input.Until,
opsNullString(input.Reason),
opsNullInt64(input.CreatedBy),
)
var out service.OpsAlertSilence
var groupID sql.NullInt64
var region sql.NullString
var createdBy sql.NullInt64
if err := row.Scan(
&out.ID,
&out.RuleID,
&out.Platform,
&groupID,
&region,
&out.Until,
&out.Reason,
&createdBy,
&out.CreatedAt,
); err != nil {
return nil, err
}
if groupID.Valid {
v := groupID.Int64
out.GroupID = &v
}
if region.Valid {
v := strings.TrimSpace(region.String)
if v != "" {
out.Region = &v
}
}
if createdBy.Valid {
v := createdBy.Int64
out.CreatedBy = &v
}
return &out, nil
}
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
if r == nil || r.db == nil {
return false, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return false, fmt.Errorf("invalid rule id")
}
platform = strings.TrimSpace(platform)
if platform == "" {
return false, nil
}
if now.IsZero() {
now = time.Now().UTC()
}
q := `
SELECT 1
FROM ops_alert_silences
WHERE rule_id = $1
AND platform = $2
AND (group_id IS NOT DISTINCT FROM $3)
AND (region IS NOT DISTINCT FROM $4)
AND until > $5
LIMIT 1`
var dummy int
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
var ev service.OpsAlertEvent
var metricValue sql.NullFloat64
......@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
args = append(args, severity)
clauses = append(clauses, "severity = $"+itoa(len(args)))
}
if filter.EmailSent != nil {
args = append(args, *filter.EmailSent)
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, *filter.StartTime)
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
......@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
}
// Cursor pagination (descending by fired_at, then id)
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
args = append(args, *filter.BeforeFiredAt)
tsArg := "$" + itoa(len(args))
args = append(args, *filter.BeforeID)
idArg := "$" + itoa(len(args))
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
}
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
if platform := strings.TrimSpace(filter.Platform); platform != "" {
args = append(args, platform)
......
......@@ -296,9 +296,10 @@ INSERT INTO ops_job_heartbeats (
last_error_at,
last_error,
last_duration_ms,
last_result,
updated_at
) VALUES (
$1,$2,$3,$4,$5,$6,NOW()
$1,$2,$3,$4,$5,$6,$7,NOW()
)
ON CONFLICT (job_name) DO UPDATE SET
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
......@@ -312,6 +313,10 @@ ON CONFLICT (job_name) DO UPDATE SET
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
END,
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
last_result = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
ELSE ops_job_heartbeats.last_result
END,
updated_at = NOW()`
_, err := r.db.ExecContext(
......@@ -323,6 +328,7 @@ ON CONFLICT (job_name) DO UPDATE SET
opsNullTime(input.LastErrorAt),
opsNullString(input.LastError),
opsNullInt(input.LastDurationMs),
opsNullString(input.LastResult),
)
return err
}
......@@ -340,6 +346,7 @@ SELECT
last_error_at,
last_error,
last_duration_ms,
last_result,
updated_at
FROM ops_job_heartbeats
ORDER BY job_name ASC`
......@@ -359,6 +366,8 @@ ORDER BY job_name ASC`
var lastError sql.NullString
var lastDuration sql.NullInt64
var lastResult sql.NullString
if err := rows.Scan(
&item.JobName,
&lastRun,
......@@ -366,6 +375,7 @@ ORDER BY job_name ASC`
&lastErrorAt,
&lastError,
&lastDuration,
&lastResult,
&item.UpdatedAt,
); err != nil {
return nil, err
......@@ -391,6 +401,10 @@ ORDER BY job_name ASC`
v := lastDuration.Int64
item.LastDurationMs = &v
}
if lastResult.Valid {
v := lastResult.String
item.LastResult = &v
}
out = append(out, &item)
}
......
package repository
import (
"context"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const proxyLatencyKeyPrefix = "proxy:latency:"
func proxyLatencyKey(proxyID int64) string {
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
}
type proxyLatencyCache struct {
rdb *redis.Client
}
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
return &proxyLatencyCache{rdb: rdb}
}
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
results := make(map[int64]*service.ProxyLatencyInfo)
if len(proxyIDs) == 0 {
return results, nil
}
keys := make([]string, 0, len(proxyIDs))
for _, id := range proxyIDs {
keys = append(keys, proxyLatencyKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return results, err
}
for i, raw := range values {
if raw == nil {
continue
}
var payload []byte
switch v := raw.(type) {
case string:
payload = []byte(v)
case []byte:
payload = v
default:
continue
}
var info service.ProxyLatencyInfo
if err := json.Unmarshal(payload, &info); err != nil {
continue
}
results[proxyIDs[i]] = &info
}
return results, nil
}
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
if info == nil {
return nil
}
payload, err := json.Marshal(info)
if err != nil {
return err
}
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
}
......@@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
const (
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
type proxyProbeService struct {
ipInfoURL string
......@@ -46,7 +50,7 @@ type proxyProbeService struct {
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
......@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
Status string `json:"status"`
Message string `json:"message"`
Query string `json:"query"`
City string `json:"city"`
Region string `json:"region"`
RegionName string `json:"regionName"`
Country string `json:"country"`
CountryCode string `json:"countryCode"`
}
body, err := io.ReadAll(resp.Body)
......@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
ipInfo.Message = "ip-api request failed"
}
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
}
region := ipInfo.RegionName
if region == "" {
region = ipInfo.Region
}
return &service.ProxyExitInfo{
IP: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
IP: ipInfo.Query,
City: ipInfo.City,
Region: region,
Country: ipInfo.Country,
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}
......@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
......@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
......@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
......
......@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
return 0, err
}
return count, nil
}
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT id, name, platform, type, notes
FROM accounts
WHERE proxy_id = $1 AND deleted_at IS NULL
ORDER BY id DESC
`, proxyID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]service.ProxyAccountSummary, 0)
for rows.Next() {
var (
id int64
name string
platform string
accType string
notes sql.NullString
)
if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
return nil, err
}
var notesPtr *string
if notes.Valid {
notesPtr = &notes.String
}
out = append(out, service.ProxyAccountSummary{
ID: id,
Name: name,
Platform: platform,
Type: accType,
Notes: notesPtr,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
......
......@@ -27,7 +27,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
RunMode: config.RunModeStandard,
Gateway: config.GatewayConfig{
Scheduling: config.GatewaySchedulingConfig{
OutboxPollIntervalSeconds: 1,
OutboxPollIntervalSeconds: 1,
FullRebuildIntervalSeconds: 0,
DbFallbackEnabled: true,
},
......
package repository
import (
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 会话限制缓存常量定义
//
// 设计说明:
// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话:
// - Key: session_limit:account:{accountID}
// - Member: sessionUUID(从 metadata.user_id 中提取)
// - Score: Unix 时间戳(会话最后活跃时间)
//
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
const (
// 会话限制键前缀
// 格式: session_limit:account:{accountID}
sessionLimitKeyPrefix = "session_limit:account:"
// 窗口费用缓存键前缀
// 格式: window_cost:account:{accountID}
windowCostKeyPrefix = "window_cost:account:"
// 窗口费用缓存 TTL(30秒)
windowCostCacheTTL = 30 * time.Second
)
var (
// registerSessionScript 注册会话活动
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = maxSessions
// ARGV[2] = idleTimeout(秒)
// ARGV[3] = sessionUUID
// 返回: 1 = 允许, 0 = 拒绝
registerSessionScript = redis.NewScript(`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`)
// refreshSessionScript 刷新会话时间戳
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
refreshSessionScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`)
// getActiveSessionCountScript 获取活跃会话数
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
getActiveSessionCountScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// isSessionActiveScript 检查会话是否活跃
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
isSessionActiveScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`)
)
type sessionLimitCache struct {
rdb *redis.Client
defaultIdleTimeout time.Duration // 默认空闲超时(用于 GetActiveSessionCount)
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
}
return &sessionLimitCache{
rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func sessionLimitKey(accountID int64) string {
return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func windowCostKey(accountID int64) string {
return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
}
// RegisterSession 注册会话活动
func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
if sessionUUID == "" || maxSessions <= 0 {
return true, nil // 无效参数,默认允许
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return true, err // 失败开放:缓存错误时允许请求通过
}
return result == 1, nil
}
// RefreshSession 刷新会话时间戳
func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
if sessionUUID == "" {
return nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
_, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
return err
}
// GetActiveSessionCount 获取活跃会话数
func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
results := make(map[int64]int, len(accountIDs))
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_, _ = pipe.Exec(ctx)
for accountID, cmd := range cmds {
if result, err := cmd.Int(); err == nil {
results[accountID] = result
}
}
return results, nil
}
// IsSessionActive 检查会话是否活跃
func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
if sessionUUID == "" {
return false, nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
key := windowCostKey(accountID)
val, err := c.rdb.Get(ctx, key).Float64()
if err == redis.Nil {
return 0, false, nil // 缓存未命中
}
if err != nil {
return 0, false, err
}
return val, true, nil
}
// SetWindowCost 设置窗口费用缓存
func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
key := windowCostKey(accountID)
return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
if len(accountIDs) == 0 {
return make(map[int64]float64), nil
}
// 构建批量查询的 keys
keys := make([]string, len(accountIDs))
for i, accountID := range accountIDs {
keys[i] = windowCostKey(accountID)
}
// 使用 MGET 批量获取
vals, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return nil, err
}
results := make(map[int64]float64, len(accountIDs))
for i, val := range vals {
if val == nil {
continue // 缓存未命中
}
// 尝试解析为 float64
switch v := val.(type) {
case string:
if cost, err := strconv.ParseFloat(v, 64); err == nil {
results[accountIDs[i]] = cost
}
case float64:
results[accountIDs[i]] = v
}
}
return results, nil
}
......@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
......@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
stream,
duration_ms,
......@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
log.Stream,
duration,
......@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
stats := &DashboardStats{}
now := time.Now().UTC()
todayUTC := truncateToDayUTC(now)
now := timezone.Now()
todayStart := timezone.Today()
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
return nil, err
}
......@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta
}
stats := &DashboardStats{}
now := time.Now().UTC()
todayUTC := truncateToDayUTC(now)
now := timezone.Now()
todayStart := timezone.Today()
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
return nil, err
}
......@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`
hourStart := now.UTC().Truncate(time.Hour)
hourStart := now.In(timezone.Location()).Truncate(time.Hour)
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
if err != sql.ErrNoRows {
return err
......@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
......@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
......@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
......@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
......@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
return result, nil
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
......@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
......@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := `
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
model,
COUNT(*) as requests,
......@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
`, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
......@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
......@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
var totalAccountCost float64
if err := scanSingleRow(
ctx,
r.sql,
......@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&totalAccountCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
if filters.AccountID > 0 {
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
......@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
......@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64
var cost float64
var actualCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
var userCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
return nil, err
}
t, _ := time.Parse("2006-01-02", date)
......@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens: tokens,
Cost: cost,
ActualCost: actualCost,
UserCost: userCost,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
var totalActualCost, totalStandardCost float64
var totalAccountCost, totalUserCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
totalActualCost += h.ActualCost
totalAccountCost += h.ActualCost
totalUserCost += h.UserCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
......@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost,
TotalCost: totalAccountCost,
TotalUserCost: totalUserCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration,
......@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
UserCost: history[i].UserCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
......@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
UserCost: highestCostDay.UserCost,
Requests: highestCostDay.Requests,
}
}
......@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
UserCost: highestRequestDay.UserCost,
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
if err != nil {
models = []ModelStat{}
}
......@@ -1994,36 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64)
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
var (
id int64
userID int64
apiKeyID int64
accountID int64
requestID sql.NullString
model string
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
outputTokens int
cacheCreationTokens int
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
inputCost float64
outputCost float64
cacheCreationCost float64
cacheReadCost float64
totalCost float64
actualCost float64
rateMultiplier float64
billingType int16
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
id int64
userID int64
apiKeyID int64
accountID int64
requestID sql.NullString
model string
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
outputTokens int
cacheCreationTokens int
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
inputCost float64
outputCost float64
cacheCreationCost float64
cacheReadCost float64
totalCost float64
actualCost float64
rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
)
if err := scanner.Scan(
......@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&totalCost,
&actualCost,
&rateMultiplier,
&accountRateMultiplier,
&billingType,
&stream,
&durationMs,
......@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost: totalCost,
ActualCost: actualCost,
RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
Stream: stream,
ImageCount: imageCount,
......@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true}
}
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
out := v.Float64
return &out
}
func nullString(v *string) sql.NullString {
if v == nil || *v == "" {
return sql.NullString{}
......
......@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
......@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
......@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
m := 0.5
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m,
CreatedAt: timezone.Today().Add(2 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().NotNil(got.AccountRateMultiplier)
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
......@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
createdAt := timezone.Today().Add(1 * time.Hour)
m1 := 1.5
m2 := 0.0
_, err := s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m1,
CreatedAt: createdAt,
})
s.Require().NoError(err)
_, err = s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 5,
OutputTokens: 5,
TotalCost: 0.5,
ActualCost: 1.0,
AccountRateMultiplier: &m2,
CreatedAt: createdAt,
})
s.Require().NoError(err)
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens)
s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(40), stats.Tokens)
// account cost = SUM(total_cost * account_rate_multiplier)
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
// standard cost = SUM(total_cost)
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
// user cost = SUM(actual_cost)
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
}
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
......@@ -416,8 +488,8 @@ func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
dayStart := truncateToDayUTC(now)
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
// 如果当前时间早于 hour2,则使用昨天的时间
if now.Before(hour2.Add(time.Hour)) {
dayStart = dayStart.Add(-24 * time.Hour)
......@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
......@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
......@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}
......
......@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时
if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
}
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
}
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
......@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache,
NewTimeoutCounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,
......@@ -69,6 +80,7 @@ var ProviderSet = wire.NewSet(
NewGeminiTokenCache,
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
......
......@@ -239,9 +239,10 @@ func TestAPIContracts(t *testing.T) {
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"billing_type": 0,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
"first_token_ms": 50,
......@@ -262,11 +263,11 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/admin/settings",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret",
......@@ -285,15 +286,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60",
})
},
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60",
})
},
method: http.MethodGet,
path: "/api/v1/admin/settings",
wantStatus: http.StatusOK,
......@@ -435,12 +436,12 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
......@@ -779,6 +780,10 @@ func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return errors.New("not implemented")
}
......@@ -799,6 +804,10 @@ func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id in
return errors.New("not implemented")
}
func (s *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return errors.New("not implemented")
}
......@@ -858,6 +867,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64)
return 0, errors.New("not implemented")
}
func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return nil, errors.New("not implemented")
}
type stubRedeemCodeRepo struct{}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
......@@ -1229,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
......
package middleware
import (
"crypto/rand"
"encoding/base64"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
const (
// CSPNonceKey is the context key for storing the CSP nonce
CSPNonceKey = "csp_nonce"
// NonceTemplate is the placeholder in CSP policy for nonce
NonceTemplate = "__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
)
// GenerateNonce generates a cryptographically secure random nonce
func GenerateNonce() string {
b := make([]byte, 16)
_, _ = rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
// GetNonceFromContext retrieves the CSP nonce from gin context
func GetNonceFromContext(c *gin.Context) string {
if nonce, exists := c.Get(CSPNonceKey); exists {
if s, ok := nonce.(string); ok {
return s
}
}
return ""
}
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
......@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy = config.DefaultCSPPolicy
}
// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
policy = enhanceCSPPolicy(policy)
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled {
c.Header("Content-Security-Policy", policy)
// Generate nonce for this request
nonce := GenerateNonce()
c.Set(CSPNonceKey, nonce)
// Replace nonce placeholder in policy
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
}
c.Next()
}
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
policy = addToDirective(policy, "script-src", NonceTemplate)
}
// Add Cloudflare Insights domain to script-src if not present
if !strings.Contains(policy, CloudflareInsightsDomain) {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
}
return policy
}
// addToDirective adds a value to a specific CSP directive.
// If the directive doesn't exist, it will be added after default-src.
func addToDirective(policy, directive, value string) string {
// Find the directive in the policy
directivePrefix := directive + " "
idx := strings.Index(policy, directivePrefix)
if idx == -1 {
// Directive not found, add it after default-src or at the beginning
defaultSrcIdx := strings.Index(policy, "default-src ")
if defaultSrcIdx != -1 {
// Find the end of default-src directive (next semicolon)
endIdx := strings.Index(policy[defaultSrcIdx:], ";")
if endIdx != -1 {
insertPos := defaultSrcIdx + endIdx + 1
// Insert new directive after default-src
return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
}
}
// Fallback: prepend the directive
return directive + " 'self' " + value + "; " + policy
}
// Find the end of this directive (next semicolon or end of string)
endIdx := strings.Index(policy[idx:], ";")
if endIdx == -1 {
// No semicolon found, directive goes to end of string
return policy + " " + value
}
// Insert value before the semicolon
insertPos := idx + endIdx
return policy[:insertPos] + " " + value + policy[insertPos:]
}
package middleware
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestGenerateNonce(t *testing.T) {
t.Run("generates_valid_base64_string", func(t *testing.T) {
nonce := GenerateNonce()
// Should be valid base64
decoded, err := base64.StdEncoding.DecodeString(nonce)
require.NoError(t, err)
// Should decode to 16 bytes
assert.Len(t, decoded, 16)
})
t.Run("generates_unique_nonces", func(t *testing.T) {
nonces := make(map[string]bool)
for i := 0; i < 100; i++ {
nonce := GenerateNonce()
assert.False(t, nonces[nonce], "nonce should be unique")
nonces[nonce] = true
}
})
t.Run("nonce_has_expected_length", func(t *testing.T) {
nonce := GenerateNonce()
// 16 bytes -> 24 chars in base64 (with padding)
assert.Len(t, nonce, 24)
})
}
func TestGetNonceFromContext(t *testing.T) {
t.Run("returns_nonce_when_present", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
expectedNonce := "test-nonce-123"
c.Set(CSPNonceKey, expectedNonce)
nonce := GetNonceFromContext(c)
assert.Equal(t, expectedNonce, nonce)
})
t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
nonce := GetNonceFromContext(c)
assert.Empty(t, nonce)
})
t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Set a non-string value
c.Set(CSPNonceKey, 12345)
// Should return empty string for wrong type (safe type assertion)
nonce := GetNonceFromContext(c)
assert.Empty(t, nonce)
})
}
func TestSecurityHeaders(t *testing.T) {
t.Run("sets_basic_security_headers", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
})
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
})
t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
// Policy is auto-enhanced with nonce and Cloudflare Insights domain
assert.Contains(t, csp, "default-src 'self'")
assert.Contains(t, csp, "'nonce-")
assert.Contains(t, csp, CloudflareInsightsDomain)
})
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
// Verify nonce is stored in context
nonce := GetNonceFromContext(c)
assert.NotEmpty(t, nonce)
assert.Contains(t, csp, "'nonce-"+nonce+"'")
})
t.Run("uses_default_policy_when_empty", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
// Default policy should contain these elements
assert.Contains(t, csp, "default-src 'self'")
})
t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: " \t\n ",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
assert.Contains(t, csp, "default-src 'self'")
})
t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
nonce := GetNonceFromContext(c)
// Count occurrences of the nonce
count := strings.Count(csp, "'nonce-"+nonce+"'")
assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
})
t.Run("calls_next_handler", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
middleware := SecurityHeaders(cfg)
nextCalled := false
router := gin.New()
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
nextCalled = true
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
router.ServeHTTP(w, req)
assert.True(t, nextCalled, "next handler should be called")
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("nonce_unique_per_request", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
nonces := make(map[string]bool)
for i := 0; i < 10; i++ {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
nonce := GetNonceFromContext(c)
assert.False(t, nonces[nonce], "nonce should be unique per request")
nonces[nonce] = true
}
})
}
func TestCSPNonceKey(t *testing.T) {
t.Run("constant_value", func(t *testing.T) {
assert.Equal(t, "csp_nonce", CSPNonceKey)
})
}
func TestNonceTemplate(t *testing.T) {
t.Run("constant_value", func(t *testing.T) {
assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
})
}
func TestEnhanceCSPPolicy(t *testing.T) {
t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self'"
enhanced := enhanceCSPPolicy(policy)
assert.Contains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, CloudflareInsightsDomain)
})
t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
enhanced := enhanceCSPPolicy(policy)
// Should not duplicate
count := strings.Count(enhanced, NonceTemplate)
assert.Equal(t, 1, count)
})
t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
enhanced := enhanceCSPPolicy(policy)
count := strings.Count(enhanced, CloudflareInsightsDomain)
assert.Equal(t, 1, count)
})
t.Run("handles_policy_without_script_src", func(t *testing.T) {
policy := "default-src 'self'"
enhanced := enhanceCSPPolicy(policy)
assert.Contains(t, enhanced, "script-src")
assert.Contains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, CloudflareInsightsDomain)
})
t.Run("preserves_existing_nonce", func(t *testing.T) {
policy := "script-src 'self' 'nonce-existing'"
enhanced := enhanceCSPPolicy(policy)
// Should not add placeholder if nonce already exists
assert.NotContains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, "'nonce-existing'")
})
}
func TestAddToDirective(t *testing.T) {
t.Run("adds_to_existing_directive", func(t *testing.T) {
policy := "script-src 'self'; style-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src 'self' https://example.com")
})
t.Run("creates_directive_if_not_exists", func(t *testing.T) {
policy := "default-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src")
assert.Contains(t, result, "https://example.com")
})
t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "https://example.com")
})
t.Run("handles_empty_policy", func(t *testing.T) {
policy := ""
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src")
assert.Contains(t, result, "https://example.com")
})
}
// Benchmark tests
func BenchmarkGenerateNonce(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateNonce()
}
}
func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment