"backend/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "8e834fd9f53b6cd8cf3735e3a3ae7a66f42b9dc1"
Commit 6901b64f authored by cyhhao's avatar cyhhao
Browse files

merge: sync upstream changes

parents 32c47b15 dae0d532
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
) )
...@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool { ...@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool {
} }
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error { func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
startUTC := start.UTC() loc := timezone.Location()
endUTC := end.UTC() startLocal := start.In(loc)
if !endUTC.After(startUTC) { endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil return nil
} }
hourStart := startUTC.Truncate(time.Hour) hourStart := startLocal.Truncate(time.Hour)
hourEnd := endUTC.Truncate(time.Hour) hourEnd := endLocal.Truncate(time.Hour)
if endUTC.After(hourEnd) { if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour) hourEnd = hourEnd.Add(time.Hour)
} }
dayStart := truncateToDayUTC(startUTC) dayStart := truncateToDay(startLocal)
dayEnd := truncateToDayUTC(endUTC) dayEnd := truncateToDay(endLocal)
if endUTC.After(dayEnd) { if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour) dayEnd = dayEnd.Add(24 * time.Hour)
} }
...@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C ...@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C
} }
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error { func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := ` query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id) INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start, date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id user_id
FROM usage_logs FROM usage_logs
WHERE created_at >= $1 AND created_at < $2 WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
` `
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) _, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err return err
} }
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error { func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := ` query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id) INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT SELECT DISTINCT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date, (bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id user_id
FROM usage_dashboard_hourly_users FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2 WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
` `
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) _, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err return err
} }
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error { func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := ` query := `
WITH hourly AS ( WITH hourly AS (
SELECT SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start, date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests, COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens,
...@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont ...@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
active_users = EXCLUDED.active_users, active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at computed_at = EXCLUDED.computed_at
` `
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) _, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err return err
} }
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error { func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := ` query := `
WITH daily AS ( WITH daily AS (
SELECT SELECT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date, (bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests, COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens,
...@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte ...@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2 WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE 'UTC')::date GROUP BY (bucket_start AT TIME ZONE $5)::date
), ),
user_counts AS ( user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users SELECT bucket_date, COUNT(*) AS active_users
...@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte ...@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
active_users = EXCLUDED.active_users, active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at computed_at = EXCLUDED.computed_at
` `
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC()) _, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err return err
} }
...@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co ...@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co
return err return err
} }
func truncateToDayUTC(t time.Time) time.Time { func truncateToDay(t time.Time) time.Time {
t = t.UTC() return timezone.StartOfDay(t)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
} }
func truncateToMonthUTC(t time.Time) time.Time { func truncateToMonthUTC(t time.Time) time.Time {
......
...@@ -11,8 +11,8 @@ import ( ...@@ -11,8 +11,8 @@ import (
) )
const ( const (
geminiTokenKeyPrefix = "gemini:token:" oauthTokenKeyPrefix = "oauth:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:" oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
) )
type geminiTokenCache struct { type geminiTokenCache struct {
...@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache { ...@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
} }
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result() return c.rdb.Get(ctx, key).Result()
} }
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err() return c.rdb.Set(ctx, key, token, ttl).Err()
} }
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result() return c.rdb.SetNX(ctx, key, 1, ttl).Result()
} }
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}
...@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID) SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
}
created, err := builder.Save(ctx) created, err := builder.Save(ctx)
if err == nil { if err == nil {
...@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly) SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 处理 FallbackGroupID:nil 时清除,否则设置 // 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil { if groupIn.FallbackGroupID != nil {
...@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearFallbackGroupID() builder = builder.ClearFallbackGroupID()
} }
// 处理 ModelRouting:nil 时清除,否则设置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
} else {
builder = builder.ClearModelRouting()
}
updated, err := builder.Save(ctx) updated, err := builder.Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
......
...@@ -55,7 +55,6 @@ INSERT INTO ops_error_logs ( ...@@ -55,7 +55,6 @@ INSERT INTO ops_error_logs (
upstream_error_message, upstream_error_message,
upstream_error_detail, upstream_error_detail,
upstream_errors, upstream_errors,
duration_ms,
time_to_first_token_ms, time_to_first_token_ms,
request_body, request_body,
request_body_truncated, request_body_truncated,
...@@ -65,7 +64,7 @@ INSERT INTO ops_error_logs ( ...@@ -65,7 +64,7 @@ INSERT INTO ops_error_logs (
retry_count, retry_count,
created_at created_at
) VALUES ( ) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35 $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
) RETURNING id` ) RETURNING id`
var id int64 var id int64
...@@ -98,7 +97,6 @@ INSERT INTO ops_error_logs ( ...@@ -98,7 +97,6 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage), opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail), opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON), opsNullString(input.UpstreamErrorsJSON),
opsNullInt(input.DurationMs),
opsNullInt64(input.TimeToFirstTokenMs), opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON), opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated, input.RequestBodyTruncated,
...@@ -135,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr ...@@ -135,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
} }
where, args := buildOpsErrorLogsWhere(filter) where, args := buildOpsErrorLogsWhere(filter)
countSQL := "SELECT COUNT(*) FROM ops_error_logs " + where countSQL := "SELECT COUNT(*) FROM ops_error_logs e " + where
var total int var total int
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil { if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
...@@ -146,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr ...@@ -146,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
argsWithLimit := append(args, pageSize, offset) argsWithLimit := append(args, pageSize, offset)
selectSQL := ` selectSQL := `
SELECT SELECT
id, e.id,
created_at, e.created_at,
error_phase, e.error_phase,
error_type, e.error_type,
severity, COALESCE(e.error_owner, ''),
COALESCE(upstream_status_code, status_code, 0), COALESCE(e.error_source, ''),
COALESCE(platform, ''), e.severity,
COALESCE(model, ''), COALESCE(e.upstream_status_code, e.status_code, 0),
duration_ms, COALESCE(e.platform, ''),
COALESCE(client_request_id, ''), COALESCE(e.model, ''),
COALESCE(request_id, ''), COALESCE(e.is_retryable, false),
COALESCE(error_message, ''), COALESCE(e.retry_count, 0),
user_id, COALESCE(e.resolved, false),
api_key_id, e.resolved_at,
account_id, e.resolved_by_user_id,
group_id, COALESCE(u2.email, ''),
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END, e.resolved_retry_id,
COALESCE(request_path, ''), COALESCE(e.client_request_id, ''),
stream COALESCE(e.request_id, ''),
FROM ops_error_logs COALESCE(e.error_message, ''),
e.user_id,
COALESCE(u.email, ''),
e.api_key_id,
e.account_id,
COALESCE(a.name, ''),
e.group_id,
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream
FROM ops_error_logs e
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN users u2 ON e.resolved_by_user_id = u2.id
` + where + ` ` + where + `
ORDER BY created_at DESC ORDER BY e.created_at DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...) rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
...@@ -179,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) ...@@ -179,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
out := make([]*service.OpsErrorLog, 0, pageSize) out := make([]*service.OpsErrorLog, 0, pageSize)
for rows.Next() { for rows.Next() {
var item service.OpsErrorLog var item service.OpsErrorLog
var latency sql.NullInt64
var statusCode sql.NullInt64 var statusCode sql.NullInt64
var clientIP sql.NullString var clientIP sql.NullString
var userID sql.NullInt64 var userID sql.NullInt64
var apiKeyID sql.NullInt64 var apiKeyID sql.NullInt64
var accountID sql.NullInt64 var accountID sql.NullInt64
var accountName string
var groupID sql.NullInt64 var groupID sql.NullInt64
var groupName string
var userEmail string
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
var resolvedByName string
var resolvedRetryID sql.NullInt64
if err := rows.Scan( if err := rows.Scan(
&item.ID, &item.ID,
&item.CreatedAt, &item.CreatedAt,
&item.Phase, &item.Phase,
&item.Type, &item.Type,
&item.Owner,
&item.Source,
&item.Severity, &item.Severity,
&statusCode, &statusCode,
&item.Platform, &item.Platform,
&item.Model, &item.Model,
&latency, &item.IsRetryable,
&item.RetryCount,
&item.Resolved,
&resolvedAt,
&resolvedBy,
&resolvedByName,
&resolvedRetryID,
&item.ClientRequestID, &item.ClientRequestID,
&item.RequestID, &item.RequestID,
&item.Message, &item.Message,
&userID, &userID,
&userEmail,
&apiKeyID, &apiKeyID,
&accountID, &accountID,
&accountName,
&groupID, &groupID,
&groupName,
&clientIP, &clientIP,
&item.RequestPath, &item.RequestPath,
&item.Stream, &item.Stream,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
if latency.Valid { if resolvedAt.Valid {
v := int(latency.Int64) t := resolvedAt.Time
item.LatencyMs = &v item.ResolvedAt = &t
}
if resolvedBy.Valid {
v := resolvedBy.Int64
item.ResolvedByUserID = &v
}
item.ResolvedByUserName = resolvedByName
if resolvedRetryID.Valid {
v := resolvedRetryID.Int64
item.ResolvedRetryID = &v
} }
item.StatusCode = int(statusCode.Int64) item.StatusCode = int(statusCode.Int64)
if clientIP.Valid { if clientIP.Valid {
...@@ -222,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) ...@@ -222,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := userID.Int64 v := userID.Int64
item.UserID = &v item.UserID = &v
} }
item.UserEmail = userEmail
if apiKeyID.Valid { if apiKeyID.Valid {
v := apiKeyID.Int64 v := apiKeyID.Int64
item.APIKeyID = &v item.APIKeyID = &v
...@@ -230,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) ...@@ -230,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := accountID.Int64 v := accountID.Int64
item.AccountID = &v item.AccountID = &v
} }
item.AccountName = accountName
if groupID.Valid { if groupID.Valid {
v := groupID.Int64 v := groupID.Int64
item.GroupID = &v item.GroupID = &v
} }
item.GroupName = groupName
out = append(out, &item) out = append(out, &item)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
...@@ -258,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service ...@@ -258,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service
q := ` q := `
SELECT SELECT
id, e.id,
created_at, e.created_at,
error_phase, e.error_phase,
error_type, e.error_type,
severity, COALESCE(e.error_owner, ''),
COALESCE(upstream_status_code, status_code, 0), COALESCE(e.error_source, ''),
COALESCE(platform, ''), e.severity,
COALESCE(model, ''), COALESCE(e.upstream_status_code, e.status_code, 0),
duration_ms, COALESCE(e.platform, ''),
COALESCE(client_request_id, ''), COALESCE(e.model, ''),
COALESCE(request_id, ''), COALESCE(e.is_retryable, false),
COALESCE(error_message, ''), COALESCE(e.retry_count, 0),
COALESCE(error_body, ''), COALESCE(e.resolved, false),
upstream_status_code, e.resolved_at,
COALESCE(upstream_error_message, ''), e.resolved_by_user_id,
COALESCE(upstream_error_detail, ''), e.resolved_retry_id,
COALESCE(upstream_errors::text, ''), COALESCE(e.client_request_id, ''),
is_business_limited, COALESCE(e.request_id, ''),
user_id, COALESCE(e.error_message, ''),
api_key_id, COALESCE(e.error_body, ''),
account_id, e.upstream_status_code,
group_id, COALESCE(e.upstream_error_message, ''),
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END, COALESCE(e.upstream_error_detail, ''),
COALESCE(request_path, ''), COALESCE(e.upstream_errors::text, ''),
stream, e.is_business_limited,
COALESCE(user_agent, ''), e.user_id,
auth_latency_ms, COALESCE(u.email, ''),
routing_latency_ms, e.api_key_id,
upstream_latency_ms, e.account_id,
response_latency_ms, COALESCE(a.name, ''),
time_to_first_token_ms, e.group_id,
COALESCE(request_body::text, ''), COALESCE(g.name, ''),
request_body_truncated, CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
request_body_bytes, COALESCE(e.request_path, ''),
COALESCE(request_headers::text, '') e.stream,
FROM ops_error_logs COALESCE(e.user_agent, ''),
WHERE id = $1 e.auth_latency_ms,
e.routing_latency_ms,
e.upstream_latency_ms,
e.response_latency_ms,
e.time_to_first_token_ms,
COALESCE(e.request_body::text, ''),
e.request_body_truncated,
e.request_body_bytes,
COALESCE(e.request_headers::text, '')
FROM ops_error_logs e
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
WHERE e.id = $1
LIMIT 1` LIMIT 1`
var out service.OpsErrorLogDetail var out service.OpsErrorLogDetail
var latency sql.NullInt64
var statusCode sql.NullInt64 var statusCode sql.NullInt64
var upstreamStatusCode sql.NullInt64 var upstreamStatusCode sql.NullInt64
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
var resolvedRetryID sql.NullInt64
var clientIP sql.NullString var clientIP sql.NullString
var userID sql.NullInt64 var userID sql.NullInt64
var apiKeyID sql.NullInt64 var apiKeyID sql.NullInt64
...@@ -318,11 +375,18 @@ LIMIT 1` ...@@ -318,11 +375,18 @@ LIMIT 1`
&out.CreatedAt, &out.CreatedAt,
&out.Phase, &out.Phase,
&out.Type, &out.Type,
&out.Owner,
&out.Source,
&out.Severity, &out.Severity,
&statusCode, &statusCode,
&out.Platform, &out.Platform,
&out.Model, &out.Model,
&latency, &out.IsRetryable,
&out.RetryCount,
&out.Resolved,
&resolvedAt,
&resolvedBy,
&resolvedRetryID,
&out.ClientRequestID, &out.ClientRequestID,
&out.RequestID, &out.RequestID,
&out.Message, &out.Message,
...@@ -333,9 +397,12 @@ LIMIT 1` ...@@ -333,9 +397,12 @@ LIMIT 1`
&out.UpstreamErrors, &out.UpstreamErrors,
&out.IsBusinessLimited, &out.IsBusinessLimited,
&userID, &userID,
&out.UserEmail,
&apiKeyID, &apiKeyID,
&accountID, &accountID,
&out.AccountName,
&groupID, &groupID,
&out.GroupName,
&clientIP, &clientIP,
&out.RequestPath, &out.RequestPath,
&out.Stream, &out.Stream,
...@@ -355,9 +422,17 @@ LIMIT 1` ...@@ -355,9 +422,17 @@ LIMIT 1`
} }
out.StatusCode = int(statusCode.Int64) out.StatusCode = int(statusCode.Int64)
if latency.Valid { if resolvedAt.Valid {
v := int(latency.Int64) t := resolvedAt.Time
out.LatencyMs = &v out.ResolvedAt = &t
}
if resolvedBy.Valid {
v := resolvedBy.Int64
out.ResolvedByUserID = &v
}
if resolvedRetryID.Valid {
v := resolvedRetryID.Int64
out.ResolvedRetryID = &v
} }
if clientIP.Valid { if clientIP.Valid {
s := clientIP.String s := clientIP.String
...@@ -487,9 +562,15 @@ SET ...@@ -487,9 +562,15 @@ SET
status = $2, status = $2,
finished_at = $3, finished_at = $3,
duration_ms = $4, duration_ms = $4,
result_request_id = $5, success = $5,
result_error_id = $6, http_status_code = $6,
error_message = $7 upstream_request_id = $7,
used_account_id = $8,
response_preview = $9,
response_truncated = $10,
result_request_id = $11,
result_error_id = $12,
error_message = $13
WHERE id = $1` WHERE id = $1`
_, err := r.db.ExecContext( _, err := r.db.ExecContext(
...@@ -499,8 +580,14 @@ WHERE id = $1` ...@@ -499,8 +580,14 @@ WHERE id = $1`
strings.TrimSpace(input.Status), strings.TrimSpace(input.Status),
nullTime(input.FinishedAt), nullTime(input.FinishedAt),
input.DurationMs, input.DurationMs,
nullBool(input.Success),
nullInt(input.HTTPStatusCode),
opsNullString(input.UpstreamRequestID),
nullInt64(input.UsedAccountID),
opsNullString(input.ResponsePreview),
nullBool(input.ResponseTruncated),
opsNullString(input.ResultRequestID), opsNullString(input.ResultRequestID),
opsNullInt64(input.ResultErrorID), nullInt64(input.ResultErrorID),
opsNullString(input.ErrorMessage), opsNullString(input.ErrorMessage),
) )
return err return err
...@@ -526,6 +613,12 @@ SELECT ...@@ -526,6 +613,12 @@ SELECT
started_at, started_at,
finished_at, finished_at,
duration_ms, duration_ms,
success,
http_status_code,
upstream_request_id,
used_account_id,
response_preview,
response_truncated,
result_request_id, result_request_id,
result_error_id, result_error_id,
error_message error_message
...@@ -540,6 +633,12 @@ LIMIT 1` ...@@ -540,6 +633,12 @@ LIMIT 1`
var startedAt sql.NullTime var startedAt sql.NullTime
var finishedAt sql.NullTime var finishedAt sql.NullTime
var durationMs sql.NullInt64 var durationMs sql.NullInt64
var success sql.NullBool
var httpStatusCode sql.NullInt64
var upstreamRequestID sql.NullString
var usedAccountID sql.NullInt64
var responsePreview sql.NullString
var responseTruncated sql.NullBool
var resultRequestID sql.NullString var resultRequestID sql.NullString
var resultErrorID sql.NullInt64 var resultErrorID sql.NullInt64
var errorMessage sql.NullString var errorMessage sql.NullString
...@@ -555,6 +654,12 @@ LIMIT 1` ...@@ -555,6 +654,12 @@ LIMIT 1`
&startedAt, &startedAt,
&finishedAt, &finishedAt,
&durationMs, &durationMs,
&success,
&httpStatusCode,
&upstreamRequestID,
&usedAccountID,
&responsePreview,
&responseTruncated,
&resultRequestID, &resultRequestID,
&resultErrorID, &resultErrorID,
&errorMessage, &errorMessage,
...@@ -579,6 +684,30 @@ LIMIT 1` ...@@ -579,6 +684,30 @@ LIMIT 1`
v := durationMs.Int64 v := durationMs.Int64
out.DurationMs = &v out.DurationMs = &v
} }
if success.Valid {
v := success.Bool
out.Success = &v
}
if httpStatusCode.Valid {
v := int(httpStatusCode.Int64)
out.HTTPStatusCode = &v
}
if upstreamRequestID.Valid {
s := upstreamRequestID.String
out.UpstreamRequestID = &s
}
if usedAccountID.Valid {
v := usedAccountID.Int64
out.UsedAccountID = &v
}
if responsePreview.Valid {
s := responsePreview.String
out.ResponsePreview = &s
}
if responseTruncated.Valid {
v := responseTruncated.Bool
out.ResponseTruncated = &v
}
if resultRequestID.Valid { if resultRequestID.Valid {
s := resultRequestID.String s := resultRequestID.String
out.ResultRequestID = &s out.ResultRequestID = &s
...@@ -602,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime { ...@@ -602,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime {
return sql.NullTime{Time: t, Valid: true} return sql.NullTime{Time: t, Valid: true}
} }
func nullBool(v *bool) sql.NullBool {
if v == nil {
return sql.NullBool{}
}
return sql.NullBool{Bool: *v, Valid: true}
}
func (r *opsRepository) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*service.OpsRetryAttempt, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if sourceErrorID <= 0 {
return nil, fmt.Errorf("invalid source_error_id")
}
if limit <= 0 {
limit = 50
}
if limit > 200 {
limit = 200
}
q := `
SELECT
r.id,
r.created_at,
COALESCE(r.requested_by_user_id, 0),
r.source_error_id,
COALESCE(r.mode, ''),
r.pinned_account_id,
COALESCE(pa.name, ''),
COALESCE(r.status, ''),
r.started_at,
r.finished_at,
r.duration_ms,
r.success,
r.http_status_code,
r.upstream_request_id,
r.used_account_id,
COALESCE(ua.name, ''),
r.response_preview,
r.response_truncated,
r.result_request_id,
r.result_error_id,
r.error_message
FROM ops_retry_attempts r
LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
LEFT JOIN accounts ua ON r.used_account_id = ua.id
WHERE r.source_error_id = $1
ORDER BY r.created_at DESC
LIMIT $2`
rows, err := r.db.QueryContext(ctx, q, sourceErrorID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]*service.OpsRetryAttempt, 0, 16)
for rows.Next() {
var item service.OpsRetryAttempt
var pinnedAccountID sql.NullInt64
var pinnedAccountName string
var requestedBy sql.NullInt64
var startedAt sql.NullTime
var finishedAt sql.NullTime
var durationMs sql.NullInt64
var success sql.NullBool
var httpStatusCode sql.NullInt64
var upstreamRequestID sql.NullString
var usedAccountID sql.NullInt64
var usedAccountName string
var responsePreview sql.NullString
var responseTruncated sql.NullBool
var resultRequestID sql.NullString
var resultErrorID sql.NullInt64
var errorMessage sql.NullString
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&requestedBy,
&item.SourceErrorID,
&item.Mode,
&pinnedAccountID,
&pinnedAccountName,
&item.Status,
&startedAt,
&finishedAt,
&durationMs,
&success,
&httpStatusCode,
&upstreamRequestID,
&usedAccountID,
&usedAccountName,
&responsePreview,
&responseTruncated,
&resultRequestID,
&resultErrorID,
&errorMessage,
); err != nil {
return nil, err
}
item.RequestedByUserID = requestedBy.Int64
if pinnedAccountID.Valid {
v := pinnedAccountID.Int64
item.PinnedAccountID = &v
}
item.PinnedAccountName = pinnedAccountName
if startedAt.Valid {
t := startedAt.Time
item.StartedAt = &t
}
if finishedAt.Valid {
t := finishedAt.Time
item.FinishedAt = &t
}
if durationMs.Valid {
v := durationMs.Int64
item.DurationMs = &v
}
if success.Valid {
v := success.Bool
item.Success = &v
}
if httpStatusCode.Valid {
v := int(httpStatusCode.Int64)
item.HTTPStatusCode = &v
}
if upstreamRequestID.Valid {
item.UpstreamRequestID = &upstreamRequestID.String
}
if usedAccountID.Valid {
v := usedAccountID.Int64
item.UsedAccountID = &v
}
item.UsedAccountName = usedAccountName
if responsePreview.Valid {
item.ResponsePreview = &responsePreview.String
}
if responseTruncated.Valid {
v := responseTruncated.Bool
item.ResponseTruncated = &v
}
if resultRequestID.Valid {
item.ResultRequestID = &resultRequestID.String
}
if resultErrorID.Valid {
v := resultErrorID.Int64
item.ResultErrorID = &v
}
if errorMessage.Valid {
item.ErrorMessage = &errorMessage.String
}
out = append(out, &item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if errorID <= 0 {
return fmt.Errorf("invalid error id")
}
q := `
UPDATE ops_error_logs
SET
resolved = $2,
resolved_at = $3,
resolved_by_user_id = $4,
resolved_retry_id = $5
WHERE id = $1`
at := sql.NullTime{}
if resolvedAt != nil && !resolvedAt.IsZero() {
at = sql.NullTime{Time: resolvedAt.UTC(), Valid: true}
} else if resolved {
now := time.Now().UTC()
at = sql.NullTime{Time: now, Valid: true}
}
_, err := r.db.ExecContext(
ctx,
q,
errorID,
resolved,
at,
nullInt64(resolvedByUserID),
nullInt64(resolvedRetryID),
)
return err
}
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses := make([]string, 0, 8) clauses := make([]string, 0, 12)
args := make([]any, 0, 8) args := make([]any, 0, 12)
clauses = append(clauses, "1=1") clauses = append(clauses, "1=1")
phaseFilter := "" phaseFilter := ""
if filter != nil { if filter != nil {
phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase)) phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
} }
// ops_error_logs primarily stores client-visible error requests (status>=400), // ops_error_logs stores client-visible error requests (status>=400),
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility. // but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
// By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase. // If Resolved is not specified, do not filter by resolved state (backward-compatible).
resolvedFilter := (*bool)(nil)
if filter != nil {
resolvedFilter = filter.Resolved
}
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" { if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400") clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
} }
if filter.StartTime != nil && !filter.StartTime.IsZero() { if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, filter.StartTime.UTC()) args = append(args, filter.StartTime.UTC())
clauses = append(clauses, "created_at >= $"+itoa(len(args))) clauses = append(clauses, "e.created_at >= $"+itoa(len(args)))
} }
if filter.EndTime != nil && !filter.EndTime.IsZero() { if filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, filter.EndTime.UTC()) args = append(args, filter.EndTime.UTC())
// Keep time-window semantics consistent with other ops queries: [start, end) // Keep time-window semantics consistent with other ops queries: [start, end)
clauses = append(clauses, "created_at < $"+itoa(len(args))) clauses = append(clauses, "e.created_at < $"+itoa(len(args)))
} }
if p := strings.TrimSpace(filter.Platform); p != "" { if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p) args = append(args, p)
...@@ -643,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { ...@@ -643,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
args = append(args, phase) args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args))) clauses = append(clauses, "error_phase = $"+itoa(len(args)))
} }
if filter != nil {
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
args = append(args, owner)
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
}
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
args = append(args, source)
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
}
}
if resolvedFilter != nil {
args = append(args, *resolvedFilter)
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
view := ""
if filter != nil {
view = strings.ToLower(strings.TrimSpace(filter.View))
}
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
case "excluded":
clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
}
if len(filter.StatusCodes) > 0 { if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes)) args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")") clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
} else if filter.StatusCodesOther {
// "Other" means: status codes not in the common list.
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
args = append(args, pq.Array(known))
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
} }
// Exact correlation keys (preferred for request↔upstream linkage).
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
args = append(args, rid)
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
}
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
args = append(args, crid)
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
}
if q := strings.TrimSpace(filter.Query); q != "" { if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%" like := "%" + q + "%"
args = append(args, like) args = append(args, like)
...@@ -654,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { ...@@ -654,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")") clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
} }
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
like := "%" + userQuery + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "u.email ILIKE $"+n)
}
return "WHERE " + strings.Join(clauses, " AND "), args return "WHERE " + strings.Join(clauses, " AND "), args
} }
......
...@@ -354,7 +354,7 @@ SELECT ...@@ -354,7 +354,7 @@ SELECT
created_at created_at
FROM ops_alert_events FROM ops_alert_events
` + where + ` ` + where + `
ORDER BY fired_at DESC ORDER BY fired_at DESC, id DESC
LIMIT ` + limitArg LIMIT ` + limitArg
rows, err := r.db.QueryContext(ctx, q, args...) rows, err := r.db.QueryContext(ctx, q, args...)
...@@ -413,6 +413,43 @@ LIMIT ` + limitArg ...@@ -413,6 +413,43 @@ LIMIT ` + limitArg
return out, nil return out, nil
} }
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return nil, fmt.Errorf("invalid event id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE id = $1`
row := r.db.QueryRowContext(ctx, q, eventID)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) { func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil { if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository") return nil, fmt.Errorf("nil ops repository")
...@@ -591,6 +628,121 @@ type opsAlertEventRow interface { ...@@ -591,6 +628,121 @@ type opsAlertEventRow interface {
Scan(dest ...any) error Scan(dest ...any) error
} }
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
if input.RuleID <= 0 {
return nil, fmt.Errorf("invalid rule_id")
}
platform := strings.TrimSpace(input.Platform)
if platform == "" {
return nil, fmt.Errorf("invalid platform")
}
if input.Until.IsZero() {
return nil, fmt.Errorf("invalid until")
}
q := `
INSERT INTO ops_alert_silences (
rule_id,
platform,
group_id,
region,
until,
reason,
created_by,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,NOW()
)
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
row := r.db.QueryRowContext(
ctx,
q,
input.RuleID,
platform,
opsNullInt64(input.GroupID),
opsNullString(input.Region),
input.Until,
opsNullString(input.Reason),
opsNullInt64(input.CreatedBy),
)
var out service.OpsAlertSilence
var groupID sql.NullInt64
var region sql.NullString
var createdBy sql.NullInt64
if err := row.Scan(
&out.ID,
&out.RuleID,
&out.Platform,
&groupID,
&region,
&out.Until,
&out.Reason,
&createdBy,
&out.CreatedAt,
); err != nil {
return nil, err
}
if groupID.Valid {
v := groupID.Int64
out.GroupID = &v
}
if region.Valid {
v := strings.TrimSpace(region.String)
if v != "" {
out.Region = &v
}
}
if createdBy.Valid {
v := createdBy.Int64
out.CreatedBy = &v
}
return &out, nil
}
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
if r == nil || r.db == nil {
return false, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return false, fmt.Errorf("invalid rule id")
}
platform = strings.TrimSpace(platform)
if platform == "" {
return false, nil
}
if now.IsZero() {
now = time.Now().UTC()
}
q := `
SELECT 1
FROM ops_alert_silences
WHERE rule_id = $1
AND platform = $2
AND (group_id IS NOT DISTINCT FROM $3)
AND (region IS NOT DISTINCT FROM $4)
AND until > $5
LIMIT 1`
var dummy int
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) { func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
var ev service.OpsAlertEvent var ev service.OpsAlertEvent
var metricValue sql.NullFloat64 var metricValue sql.NullFloat64
...@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an ...@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
args = append(args, severity) args = append(args, severity)
clauses = append(clauses, "severity = $"+itoa(len(args))) clauses = append(clauses, "severity = $"+itoa(len(args)))
} }
if filter.EmailSent != nil {
args = append(args, *filter.EmailSent)
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
}
if filter.StartTime != nil && !filter.StartTime.IsZero() { if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, *filter.StartTime) args = append(args, *filter.StartTime)
clauses = append(clauses, "fired_at >= $"+itoa(len(args))) clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
...@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an ...@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
clauses = append(clauses, "fired_at < $"+itoa(len(args))) clauses = append(clauses, "fired_at < $"+itoa(len(args)))
} }
// Cursor pagination (descending by fired_at, then id)
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
args = append(args, *filter.BeforeFiredAt)
tsArg := "$" + itoa(len(args))
args = append(args, *filter.BeforeID)
idArg := "$" + itoa(len(args))
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
}
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes. // Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
if platform := strings.TrimSpace(filter.Platform); platform != "" { if platform := strings.TrimSpace(filter.Platform); platform != "" {
args = append(args, platform) args = append(args, platform)
......
...@@ -296,9 +296,10 @@ INSERT INTO ops_job_heartbeats ( ...@@ -296,9 +296,10 @@ INSERT INTO ops_job_heartbeats (
last_error_at, last_error_at,
last_error, last_error,
last_duration_ms, last_duration_ms,
last_result,
updated_at updated_at
) VALUES ( ) VALUES (
$1,$2,$3,$4,$5,$6,NOW() $1,$2,$3,$4,$5,$6,$7,NOW()
) )
ON CONFLICT (job_name) DO UPDATE SET ON CONFLICT (job_name) DO UPDATE SET
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at), last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
...@@ -312,6 +313,10 @@ ON CONFLICT (job_name) DO UPDATE SET ...@@ -312,6 +313,10 @@ ON CONFLICT (job_name) DO UPDATE SET
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error) ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
END, END,
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms), last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
last_result = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
ELSE ops_job_heartbeats.last_result
END,
updated_at = NOW()` updated_at = NOW()`
_, err := r.db.ExecContext( _, err := r.db.ExecContext(
...@@ -323,6 +328,7 @@ ON CONFLICT (job_name) DO UPDATE SET ...@@ -323,6 +328,7 @@ ON CONFLICT (job_name) DO UPDATE SET
opsNullTime(input.LastErrorAt), opsNullTime(input.LastErrorAt),
opsNullString(input.LastError), opsNullString(input.LastError),
opsNullInt(input.LastDurationMs), opsNullInt(input.LastDurationMs),
opsNullString(input.LastResult),
) )
return err return err
} }
...@@ -340,6 +346,7 @@ SELECT ...@@ -340,6 +346,7 @@ SELECT
last_error_at, last_error_at,
last_error, last_error,
last_duration_ms, last_duration_ms,
last_result,
updated_at updated_at
FROM ops_job_heartbeats FROM ops_job_heartbeats
ORDER BY job_name ASC` ORDER BY job_name ASC`
...@@ -359,6 +366,8 @@ ORDER BY job_name ASC` ...@@ -359,6 +366,8 @@ ORDER BY job_name ASC`
var lastError sql.NullString var lastError sql.NullString
var lastDuration sql.NullInt64 var lastDuration sql.NullInt64
var lastResult sql.NullString
if err := rows.Scan( if err := rows.Scan(
&item.JobName, &item.JobName,
&lastRun, &lastRun,
...@@ -366,6 +375,7 @@ ORDER BY job_name ASC` ...@@ -366,6 +375,7 @@ ORDER BY job_name ASC`
&lastErrorAt, &lastErrorAt,
&lastError, &lastError,
&lastDuration, &lastDuration,
&lastResult,
&item.UpdatedAt, &item.UpdatedAt,
); err != nil { ); err != nil {
return nil, err return nil, err
...@@ -391,6 +401,10 @@ ORDER BY job_name ASC` ...@@ -391,6 +401,10 @@ ORDER BY job_name ASC`
v := lastDuration.Int64 v := lastDuration.Int64
item.LastDurationMs = &v item.LastDurationMs = &v
} }
if lastResult.Valid {
v := lastResult.String
item.LastResult = &v
}
out = append(out, &item) out = append(out, &item)
} }
......
package repository
import (
"context"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const proxyLatencyKeyPrefix = "proxy:latency:"
func proxyLatencyKey(proxyID int64) string {
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
}
type proxyLatencyCache struct {
rdb *redis.Client
}
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
return &proxyLatencyCache{rdb: rdb}
}
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
results := make(map[int64]*service.ProxyLatencyInfo)
if len(proxyIDs) == 0 {
return results, nil
}
keys := make([]string, 0, len(proxyIDs))
for _, id := range proxyIDs {
keys = append(keys, proxyLatencyKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return results, err
}
for i, raw := range values {
if raw == nil {
continue
}
var payload []byte
switch v := raw.(type) {
case string:
payload = []byte(v)
case []byte:
payload = v
default:
continue
}
var info service.ProxyLatencyInfo
if err := json.Unmarshal(payload, &info); err != nil {
continue
}
results[proxyIDs[i]] = &info
}
return results, nil
}
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
if info == nil {
return nil
}
payload, err := json.Marshal(info)
if err != nil {
return err
}
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
}
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
...@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { ...@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
} }
} }
const defaultIPInfoURL = "https://ipinfo.io/json" const (
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
type proxyProbeService struct { type proxyProbeService struct {
ipInfoURL string ipInfoURL string
...@@ -46,7 +50,7 @@ type proxyProbeService struct { ...@@ -46,7 +50,7 @@ type proxyProbeService struct {
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: 15 * time.Second, Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify, InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true, ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP, ValidateResolvedIP: s.validateResolvedIP,
...@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ...@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
} }
var ipInfo struct { var ipInfo struct {
IP string `json:"ip"` Status string `json:"status"`
City string `json:"city"` Message string `json:"message"`
Region string `json:"region"` Query string `json:"query"`
Country string `json:"country"` City string `json:"city"`
Region string `json:"region"`
RegionName string `json:"regionName"`
Country string `json:"country"`
CountryCode string `json:"countryCode"`
} }
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
...@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ...@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if err := json.Unmarshal(body, &ipInfo); err != nil { if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err) return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
} }
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
ipInfo.Message = "ip-api request failed"
}
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
}
region := ipInfo.RegionName
if region == "" {
region = ipInfo.Region
}
return &service.ProxyExitInfo{ return &service.ProxyExitInfo{
IP: ipInfo.IP, IP: ipInfo.Query,
City: ipInfo.City, City: ipInfo.City,
Region: ipInfo.Region, Region: region,
Country: ipInfo.Country, Country: ipInfo.Country,
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil }, latencyMs, nil
} }
...@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct { ...@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() { func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.prober = &proxyProbeService{ s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json", ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true, allowPrivateHosts: true,
} }
} }
...@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { ...@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`) _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
})) }))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
...@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { ...@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "c", info.City) require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region) require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country) require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request // Verify proxy received the request
select { select {
case uri := <-seen: case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy") require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default: default:
require.Fail(s.T(), "expected proxy to receive request") require.Fail(s.T(), "expected proxy to receive request")
} }
......
...@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, ...@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64 var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil { if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
return 0, err return 0, err
} }
return count, nil return count, nil
} }
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT id, name, platform, type, notes
FROM accounts
WHERE proxy_id = $1 AND deleted_at IS NULL
ORDER BY id DESC
`, proxyID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]service.ProxyAccountSummary, 0)
for rows.Next() {
var (
id int64
name string
platform string
accType string
notes sql.NullString
)
if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
return nil, err
}
var notesPtr *string
if notes.Valid {
notesPtr = &notes.String
}
out = append(out, service.ProxyAccountSummary{
ID: id,
Name: name,
Platform: platform,
Type: accType,
Notes: notesPtr,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) { func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id") rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
......
...@@ -27,7 +27,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) { ...@@ -27,7 +27,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
RunMode: config.RunModeStandard, RunMode: config.RunModeStandard,
Gateway: config.GatewayConfig{ Gateway: config.GatewayConfig{
Scheduling: config.GatewaySchedulingConfig{ Scheduling: config.GatewaySchedulingConfig{
OutboxPollIntervalSeconds: 1, OutboxPollIntervalSeconds: 1,
FullRebuildIntervalSeconds: 0, FullRebuildIntervalSeconds: 0,
DbFallbackEnabled: true, DbFallbackEnabled: true,
}, },
......
package repository
import (
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 会话限制缓存常量定义
//
// 设计说明:
// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话:
// - Key: session_limit:account:{accountID}
// - Member: sessionUUID(从 metadata.user_id 中提取)
// - Score: Unix 时间戳(会话最后活跃时间)
//
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
const (
// 会话限制键前缀
// 格式: session_limit:account:{accountID}
sessionLimitKeyPrefix = "session_limit:account:"
// 窗口费用缓存键前缀
// 格式: window_cost:account:{accountID}
windowCostKeyPrefix = "window_cost:account:"
// 窗口费用缓存 TTL(30秒)
windowCostCacheTTL = 30 * time.Second
)
var (
// registerSessionScript 注册会话活动
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = maxSessions
// ARGV[2] = idleTimeout(秒)
// ARGV[3] = sessionUUID
// 返回: 1 = 允许, 0 = 拒绝
registerSessionScript = redis.NewScript(`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`)
// refreshSessionScript 刷新会话时间戳
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
refreshSessionScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`)
// getActiveSessionCountScript 获取活跃会话数
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
getActiveSessionCountScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// isSessionActiveScript 检查会话是否活跃
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
isSessionActiveScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`)
)
type sessionLimitCache struct {
rdb *redis.Client
defaultIdleTimeout time.Duration // 默认空闲超时(用于 GetActiveSessionCount)
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
}
return &sessionLimitCache{
rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func sessionLimitKey(accountID int64) string {
return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func windowCostKey(accountID int64) string {
return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
}
// RegisterSession 注册会话活动
func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
if sessionUUID == "" || maxSessions <= 0 {
return true, nil // 无效参数,默认允许
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return true, err // 失败开放:缓存错误时允许请求通过
}
return result == 1, nil
}
// RefreshSession 刷新会话时间戳
func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
if sessionUUID == "" {
return nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
_, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
return err
}
// GetActiveSessionCount 获取活跃会话数
func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
results := make(map[int64]int, len(accountIDs))
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_, _ = pipe.Exec(ctx)
for accountID, cmd := range cmds {
if result, err := cmd.Int(); err == nil {
results[accountID] = result
}
}
return results, nil
}
// IsSessionActive 检查会话是否活跃
func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
if sessionUUID == "" {
return false, nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
key := windowCostKey(accountID)
val, err := c.rdb.Get(ctx, key).Float64()
if err == redis.Nil {
return 0, false, nil // 缓存未命中
}
if err != nil {
return 0, false, err
}
return val, true, nil
}
// SetWindowCost 设置窗口费用缓存
func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
key := windowCostKey(accountID)
return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
if len(accountIDs) == 0 {
return make(map[int64]float64), nil
}
// 构建批量查询的 keys
keys := make([]string, len(accountIDs))
for i, accountID := range accountIDs {
keys[i] = windowCostKey(accountID)
}
// 使用 MGET 批量获取
vals, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return nil, err
}
results := make(map[int64]float64, len(accountIDs))
for i, val := range vals {
if val == nil {
continue // 缓存未命中
}
// 尝试解析为 float64
switch v := val.(type) {
case string:
if cost, err := strconv.ParseFloat(v, 64); err == nil {
results[accountIDs[i]] = cost
}
case float64:
results[accountIDs[i]] = v
}
}
return results, nil
}
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
...@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost, total_cost,
actual_cost, actual_cost,
rate_multiplier, rate_multiplier,
account_rate_multiplier,
billing_type, billing_type,
stream, stream,
duration_ms, duration_ms,
...@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.TotalCost, log.TotalCost,
log.ActualCost, log.ActualCost,
rateMultiplier, rateMultiplier,
log.AccountRateMultiplier,
log.BillingType, log.BillingType,
log.Stream, log.Stream,
duration, duration,
...@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats ...@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
stats := &DashboardStats{} stats := &DashboardStats{}
now := time.Now().UTC() now := timezone.Now()
todayUTC := truncateToDayUTC(now) todayStart := timezone.Today()
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil { if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err return nil, err
} }
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil { if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
return nil, err return nil, err
} }
...@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta ...@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta
} }
stats := &DashboardStats{} stats := &DashboardStats{}
now := time.Now().UTC() now := timezone.Now()
todayUTC := truncateToDayUTC(now) todayStart := timezone.Today()
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil { if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err return nil, err
} }
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil { if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
return nil, err return nil, err
} }
...@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte ...@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
FROM usage_dashboard_hourly FROM usage_dashboard_hourly
WHERE bucket_start = $1 WHERE bucket_start = $1
` `
hourStart := now.UTC().Truncate(time.Hour) hourStart := now.In(timezone.Location()).Truncate(time.Hour)
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil { if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
return err return err
...@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT SELECT
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 WHERE account_id = $1 AND created_at >= $2
` `
...@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&stats.Requests, &stats.Requests,
&stats.Tokens, &stats.Tokens,
&stats.Cost, &stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ...@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT SELECT
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 WHERE account_id = $1 AND created_at >= $2
` `
...@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ...@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&stats.Requests, &stats.Requests,
&stats.Tokens, &stats.Tokens,
&stats.Cost, &stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe ...@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
return result, nil return result, nil
} }
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters // GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
...@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1) query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID) args = append(args, apiKeyID)
} }
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY date ORDER BY date ASC" query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
...@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return results, nil return results, nil
} }
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters // GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
query := ` actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT SELECT
model, model,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start ...@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost %s
FROM usage_logs FROM usage_logs
WHERE created_at >= $1 AND created_at < $2 WHERE created_at >= $1 AND created_at < $2
` `, actualCostExpr)
args := []any{startTime, endTime} args := []any{startTime, endTime}
if userID > 0 { if userID > 0 {
...@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start ...@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1) query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID) args = append(args, accountID)
} }
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY model ORDER BY total_tokens DESC" query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
...@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs FROM usage_logs
%s %s
`, buildWhere(conditions)) `, buildWhere(conditions))
stats := &UsageStats{} stats := &UsageStats{}
var totalAccountCost float64
if err := scanSingleRow( if err := scanSingleRow(
ctx, ctx,
r.sql, r.sql,
...@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&stats.TotalCacheTokens, &stats.TotalCacheTokens,
&stats.TotalCost, &stats.TotalCost,
&stats.TotalActualCost, &stats.TotalActualCost,
&totalAccountCost,
&stats.AverageDurationMs, &stats.AverageDurationMs,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
if filters.AccountID > 0 {
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil return stats, nil
} }
...@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date GROUP BY date
...@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64 var tokens int64
var cost float64 var cost float64
var actualCost float64 var actualCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil { var userCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
return nil, err return nil, err
} }
t, _ := time.Parse("2006-01-02", date) t, _ := time.Parse("2006-01-02", date)
...@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens: tokens, Tokens: tokens,
Cost: cost, Cost: cost,
ActualCost: actualCost, ActualCost: actualCost,
UserCost: userCost,
}) })
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
var totalActualCost, totalStandardCost float64 var totalAccountCost, totalUserCost, totalStandardCost float64
var totalRequests, totalTokens int64 var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history { for i := range history {
h := &history[i] h := &history[i]
totalActualCost += h.ActualCost totalAccountCost += h.ActualCost
totalUserCost += h.UserCost
totalStandardCost += h.Cost totalStandardCost += h.Cost
totalRequests += h.Requests totalRequests += h.Requests
totalTokens += h.Tokens totalTokens += h.Tokens
...@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary := AccountUsageSummary{ summary := AccountUsageSummary{
Days: daysCount, Days: daysCount,
ActualDaysUsed: actualDaysUsed, ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost, TotalCost: totalAccountCost,
TotalUserCost: totalUserCost,
TotalStandardCost: totalStandardCost, TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests, TotalRequests: totalRequests,
TotalTokens: totalTokens, TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed), AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed), AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed), AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration, AvgDurationMs: avgDuration,
...@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary.Today = &struct { summary.Today = &struct {
Date string `json:"date"` Date string `json:"date"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"` Tokens int64 `json:"tokens"`
}{ }{
Date: history[i].Date, Date: history[i].Date,
Cost: history[i].ActualCost, Cost: history[i].ActualCost,
UserCost: history[i].UserCost,
Requests: history[i].Requests, Requests: history[i].Requests,
Tokens: history[i].Tokens, Tokens: history[i].Tokens,
} }
...@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date string `json:"date"` Date string `json:"date"`
Label string `json:"label"` Label string `json:"label"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
}{ }{
Date: highestCostDay.Date, Date: highestCostDay.Date,
Label: highestCostDay.Label, Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost, Cost: highestCostDay.ActualCost,
UserCost: highestCostDay.UserCost,
Requests: highestCostDay.Requests, Requests: highestCostDay.Requests,
} }
} }
...@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label string `json:"label"` Label string `json:"label"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
}{ }{
Date: highestRequestDay.Date, Date: highestRequestDay.Date,
Label: highestRequestDay.Label, Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests, Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost, Cost: highestRequestDay.ActualCost,
UserCost: highestRequestDay.UserCost,
} }
} }
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID) models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
if err != nil { if err != nil {
models = []ModelStat{} models = []ModelStat{}
} }
...@@ -1994,36 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64) ...@@ -1994,36 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64)
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) { func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
var ( var (
id int64 id int64
userID int64 userID int64
apiKeyID int64 apiKeyID int64
accountID int64 accountID int64
requestID sql.NullString requestID sql.NullString
model string model string
groupID sql.NullInt64 groupID sql.NullInt64
subscriptionID sql.NullInt64 subscriptionID sql.NullInt64
inputTokens int inputTokens int
outputTokens int outputTokens int
cacheCreationTokens int cacheCreationTokens int
cacheReadTokens int cacheReadTokens int
cacheCreation5m int cacheCreation5m int
cacheCreation1h int cacheCreation1h int
inputCost float64 inputCost float64
outputCost float64 outputCost float64
cacheCreationCost float64 cacheCreationCost float64
cacheReadCost float64 cacheReadCost float64
totalCost float64 totalCost float64
actualCost float64 actualCost float64
rateMultiplier float64 rateMultiplier float64
billingType int16 accountRateMultiplier sql.NullFloat64
stream bool billingType int16
durationMs sql.NullInt64 stream bool
firstTokenMs sql.NullInt64 durationMs sql.NullInt64
userAgent sql.NullString firstTokenMs sql.NullInt64
ipAddress sql.NullString userAgent sql.NullString
imageCount int ipAddress sql.NullString
imageSize sql.NullString imageCount int
createdAt time.Time imageSize sql.NullString
createdAt time.Time
) )
if err := scanner.Scan( if err := scanner.Scan(
...@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&totalCost, &totalCost,
&actualCost, &actualCost,
&rateMultiplier, &rateMultiplier,
&accountRateMultiplier,
&billingType, &billingType,
&stream, &stream,
&durationMs, &durationMs,
...@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost: totalCost, TotalCost: totalCost,
ActualCost: actualCost, ActualCost: actualCost,
RateMultiplier: rateMultiplier, RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType), BillingType: int8(billingType),
Stream: stream, Stream: stream,
ImageCount: imageCount, ImageCount: imageCount,
...@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 { ...@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true} return sql.NullInt64{Int64: int64(*v), Valid: true}
} }
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
out := v.Float64
return &out
}
func nullString(v *string) sql.NullString { func nullString(v *string) sql.NullString {
if v == nil || *v == "" { if v == nil || *v == "" {
return sql.NullString{} return sql.NullString{}
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) { ...@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite)) suite.Run(t, new(UsageLogRepoSuite))
} }
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog { func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
...@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() { ...@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s.Require().Error(err, "expected error for non-existent ID") s.Require().Error(err, "expected error for non-existent ID")
} }
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
m := 0.5
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m,
CreatedAt: timezone.Today().Add(2 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().NotNil(got.AccountRateMultiplier)
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// --- Delete --- // --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
...@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { ...@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) createdAt := timezone.Today().Add(1 * time.Hour)
m1 := 1.5
m2 := 0.0
_, err := s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m1,
CreatedAt: createdAt,
})
s.Require().NoError(err)
_, err = s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 5,
OutputTokens: 5,
TotalCost: 0.5,
ActualCost: 1.0,
AccountRateMultiplier: &m2,
CreatedAt: createdAt,
})
s.Require().NoError(err)
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID) stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats") s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests) s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens) s.Require().Equal(int64(40), stats.Tokens)
// account cost = SUM(total_cost * account_rate_multiplier)
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
// standard cost = SUM(total_cost)
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
// user cost = SUM(actual_cost)
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
} }
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() { func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
...@@ -416,8 +488,8 @@ func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() { ...@@ -416,8 +488,8 @@ func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去 // 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期) // 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
dayStart := truncateToDayUTC(now) dayStart := truncateToDayUTC(now)
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00 hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00 hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
// 如果当前时间早于 hour2,则使用昨天的时间 // 如果当前时间早于 hour2,则使用昨天的时间
if now.Before(hour2.Add(time.Hour)) { if now.Before(hour2.Add(time.Hour)) {
dayStart = dayStart.Add(-24 * time.Hour) dayStart = dayStart.Add(-24 * time.Hour)
...@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { ...@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour) endTime := base.Add(48 * time.Hour)
// Test with user filter // Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with apiKey filter // Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with both filters // Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
...@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { ...@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour) endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
...@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
// Test with user filter // Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0) stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with apiKey filter // Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with account filter // Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
} }
......
...@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient ...@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return NewPricingRemoteClient(cfg.Update.ProxyURL) return NewPricingRemoteClient(cfg.Update.ProxyURL)
} }
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时
if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
}
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
}
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
...@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet( ...@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache, NewTempUnschedCache,
NewTimeoutCounterCache, NewTimeoutCounterCache,
ProvideConcurrencyCache, ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewDashboardCache, NewDashboardCache,
NewEmailCache, NewEmailCache,
NewIdentityCache, NewIdentityCache,
...@@ -69,6 +80,7 @@ var ProviderSet = wire.NewSet( ...@@ -69,6 +80,7 @@ var ProviderSet = wire.NewSet(
NewGeminiTokenCache, NewGeminiTokenCache,
NewSchedulerCache, NewSchedulerCache,
NewSchedulerOutboxRepository, NewSchedulerOutboxRepository,
NewProxyLatencyCache,
// HTTP service ports (DI Strategy A: return interface directly) // HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier, NewTurnstileVerifier,
......
...@@ -239,9 +239,10 @@ func TestAPIContracts(t *testing.T) { ...@@ -239,9 +239,10 @@ func TestAPIContracts(t *testing.T) {
"cache_creation_cost": 0, "cache_creation_cost": 0,
"cache_read_cost": 0, "cache_read_cost": 0,
"total_cost": 0.5, "total_cost": 0.5,
"actual_cost": 0.5, "actual_cost": 0.5,
"rate_multiplier": 1, "rate_multiplier": 1,
"billing_type": 0, "account_rate_multiplier": null,
"billing_type": 0,
"stream": true, "stream": true,
"duration_ms": 100, "duration_ms": 100,
"first_token_ms": 50, "first_token_ms": 50,
...@@ -262,11 +263,11 @@ func TestAPIContracts(t *testing.T) { ...@@ -262,11 +263,11 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/admin/settings", name: "GET /api/v1/admin/settings",
setup: func(t *testing.T, deps *contractDeps) { setup: func(t *testing.T, deps *contractDeps) {
t.Helper() t.Helper()
deps.settingRepo.SetAll(map[string]string{ deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false", service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySMTPHost: "smtp.example.com", service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587", service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user", service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret", service.SettingKeySMTPPassword: "secret",
...@@ -285,15 +286,15 @@ func TestAPIContracts(t *testing.T) { ...@@ -285,15 +286,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support", service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com", service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyOpsMonitoringEnabled: "false", service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true", service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto", service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60", service.SettingKeyOpsMetricsIntervalSeconds: "60",
}) })
}, },
method: http.MethodGet, method: http.MethodGet,
path: "/api/v1/admin/settings", path: "/api/v1/admin/settings",
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
...@@ -435,12 +436,12 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -435,12 +436,12 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
...@@ -779,6 +780,10 @@ func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id ...@@ -779,6 +780,10 @@ func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id
return errors.New("not implemented") return errors.New("not implemented")
} }
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
...@@ -799,6 +804,10 @@ func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id in ...@@ -799,6 +804,10 @@ func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id in
return errors.New("not implemented") return errors.New("not implemented")
} }
func (s *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
...@@ -858,6 +867,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) ...@@ -858,6 +867,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64)
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return nil, errors.New("not implemented")
}
type stubRedeemCodeRepo struct{} type stubRedeemCodeRepo struct{}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error { func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
...@@ -1229,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D ...@@ -1229,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) { func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) { func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
package middleware package middleware
import ( import (
"crypto/rand"
"encoding/base64"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
const (
// CSPNonceKey is the context key for storing the CSP nonce
CSPNonceKey = "csp_nonce"
// NonceTemplate is the placeholder in CSP policy for nonce
NonceTemplate = "__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
)
// GenerateNonce generates a cryptographically secure random nonce
func GenerateNonce() string {
b := make([]byte, 16)
_, _ = rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
// GetNonceFromContext retrieves the CSP nonce from gin context
func GetNonceFromContext(c *gin.Context) string {
if nonce, exists := c.Get(CSPNonceKey); exists {
if s, ok := nonce.(string); ok {
return s
}
}
return ""
}
// SecurityHeaders sets baseline security headers for all responses. // SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy) policy := strings.TrimSpace(cfg.Policy)
...@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { ...@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy = config.DefaultCSPPolicy policy = config.DefaultCSPPolicy
} }
// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
policy = enhanceCSPPolicy(policy)
return func(c *gin.Context) { return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY") c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin") c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled { if cfg.Enabled {
c.Header("Content-Security-Policy", policy) // Generate nonce for this request
nonce := GenerateNonce()
c.Set(CSPNonceKey, nonce)
// Replace nonce placeholder in policy
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
} }
c.Next() c.Next()
} }
} }
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
policy = addToDirective(policy, "script-src", NonceTemplate)
}
// Add Cloudflare Insights domain to script-src if not present
if !strings.Contains(policy, CloudflareInsightsDomain) {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
}
return policy
}
// addToDirective adds a value to a specific CSP directive.
// If the directive doesn't exist, it will be added after default-src.
func addToDirective(policy, directive, value string) string {
// Find the directive in the policy
directivePrefix := directive + " "
idx := strings.Index(policy, directivePrefix)
if idx == -1 {
// Directive not found, add it after default-src or at the beginning
defaultSrcIdx := strings.Index(policy, "default-src ")
if defaultSrcIdx != -1 {
// Find the end of default-src directive (next semicolon)
endIdx := strings.Index(policy[defaultSrcIdx:], ";")
if endIdx != -1 {
insertPos := defaultSrcIdx + endIdx + 1
// Insert new directive after default-src
return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
}
}
// Fallback: prepend the directive
return directive + " 'self' " + value + "; " + policy
}
// Find the end of this directive (next semicolon or end of string)
endIdx := strings.Index(policy[idx:], ";")
if endIdx == -1 {
// No semicolon found, directive goes to end of string
return policy + " " + value
}
// Insert value before the semicolon
insertPos := idx + endIdx
return policy[:insertPos] + " " + value + policy[insertPos:]
}
package middleware
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestGenerateNonce(t *testing.T) {
t.Run("generates_valid_base64_string", func(t *testing.T) {
nonce := GenerateNonce()
// Should be valid base64
decoded, err := base64.StdEncoding.DecodeString(nonce)
require.NoError(t, err)
// Should decode to 16 bytes
assert.Len(t, decoded, 16)
})
t.Run("generates_unique_nonces", func(t *testing.T) {
nonces := make(map[string]bool)
for i := 0; i < 100; i++ {
nonce := GenerateNonce()
assert.False(t, nonces[nonce], "nonce should be unique")
nonces[nonce] = true
}
})
t.Run("nonce_has_expected_length", func(t *testing.T) {
nonce := GenerateNonce()
// 16 bytes -> 24 chars in base64 (with padding)
assert.Len(t, nonce, 24)
})
}
func TestGetNonceFromContext(t *testing.T) {
t.Run("returns_nonce_when_present", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
expectedNonce := "test-nonce-123"
c.Set(CSPNonceKey, expectedNonce)
nonce := GetNonceFromContext(c)
assert.Equal(t, expectedNonce, nonce)
})
t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
nonce := GetNonceFromContext(c)
assert.Empty(t, nonce)
})
t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Set a non-string value
c.Set(CSPNonceKey, 12345)
// Should return empty string for wrong type (safe type assertion)
nonce := GetNonceFromContext(c)
assert.Empty(t, nonce)
})
}
func TestSecurityHeaders(t *testing.T) {
t.Run("sets_basic_security_headers", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
})
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
})
t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
// Policy is auto-enhanced with nonce and Cloudflare Insights domain
assert.Contains(t, csp, "default-src 'self'")
assert.Contains(t, csp, "'nonce-")
assert.Contains(t, csp, CloudflareInsightsDomain)
})
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
// Verify nonce is stored in context
nonce := GetNonceFromContext(c)
assert.NotEmpty(t, nonce)
assert.Contains(t, csp, "'nonce-"+nonce+"'")
})
t.Run("uses_default_policy_when_empty", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
// Default policy should contain these elements
assert.Contains(t, csp, "default-src 'self'")
})
t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: " \t\n ",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
assert.Contains(t, csp, "default-src 'self'")
})
t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
nonce := GetNonceFromContext(c)
// Count occurrences of the nonce
count := strings.Count(csp, "'nonce-"+nonce+"'")
assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
})
t.Run("calls_next_handler", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
middleware := SecurityHeaders(cfg)
nextCalled := false
router := gin.New()
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
nextCalled = true
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
router.ServeHTTP(w, req)
assert.True(t, nextCalled, "next handler should be called")
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("nonce_unique_per_request", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
nonces := make(map[string]bool)
for i := 0; i < 10; i++ {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
nonce := GetNonceFromContext(c)
assert.False(t, nonces[nonce], "nonce should be unique per request")
nonces[nonce] = true
}
})
}
func TestCSPNonceKey(t *testing.T) {
t.Run("constant_value", func(t *testing.T) {
assert.Equal(t, "csp_nonce", CSPNonceKey)
})
}
func TestNonceTemplate(t *testing.T) {
t.Run("constant_value", func(t *testing.T) {
assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
})
}
func TestEnhanceCSPPolicy(t *testing.T) {
t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self'"
enhanced := enhanceCSPPolicy(policy)
assert.Contains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, CloudflareInsightsDomain)
})
t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
enhanced := enhanceCSPPolicy(policy)
// Should not duplicate
count := strings.Count(enhanced, NonceTemplate)
assert.Equal(t, 1, count)
})
t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
enhanced := enhanceCSPPolicy(policy)
count := strings.Count(enhanced, CloudflareInsightsDomain)
assert.Equal(t, 1, count)
})
t.Run("handles_policy_without_script_src", func(t *testing.T) {
policy := "default-src 'self'"
enhanced := enhanceCSPPolicy(policy)
assert.Contains(t, enhanced, "script-src")
assert.Contains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, CloudflareInsightsDomain)
})
t.Run("preserves_existing_nonce", func(t *testing.T) {
policy := "script-src 'self' 'nonce-existing'"
enhanced := enhanceCSPPolicy(policy)
// Should not add placeholder if nonce already exists
assert.NotContains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, "'nonce-existing'")
})
}
func TestAddToDirective(t *testing.T) {
t.Run("adds_to_existing_directive", func(t *testing.T) {
policy := "script-src 'self'; style-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src 'self' https://example.com")
})
t.Run("creates_directive_if_not_exists", func(t *testing.T) {
policy := "default-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src")
assert.Contains(t, result, "https://example.com")
})
t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "https://example.com")
})
t.Run("handles_empty_policy", func(t *testing.T) {
policy := ""
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src")
assert.Contains(t, result, "https://example.com")
})
}
// Benchmark tests
func BenchmarkGenerateNonce(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateNonce()
}
}
func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment