"frontend/src/components/vscode:/vscode.git/clone" did not exist on "1dd3158c7e263ada7b2ea466c27aa1f584756991"
Commit 7331220e authored by Edric Li's avatar Edric Li
Browse files

Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
parents fb86002e 4f13c8de
...@@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) { ...@@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler) s.proxySrv = httptest.NewServer(handler)
} }
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() { func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
_, err := createProxyTransport("://bad") _, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
require.Error(s.T(), err) require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "invalid proxy URL") require.ErrorContains(s.T(), err, "failed to create proxy client")
} }
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() { func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
_, err := createProxyTransport("ftp://127.0.0.1:1") _, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
require.Error(s.T(), err) require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "unsupported proxy protocol") require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
require.NoError(s.T(), err, "createProxyTransport")
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
} }
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
......
...@@ -178,7 +178,7 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in ...@@ -178,7 +178,7 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) { func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id") rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -168,7 +168,8 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC ...@@ -168,7 +168,8 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now() now := time.Now()
affected, err := r.client.RedeemCode.Update(). client := clientFromContext(ctx, r.client)
affected, err := client.RedeemCode.Update().
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)). Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
SetStatus(service.StatusUsed). SetStatus(service.StatusUsed).
SetUsedBy(userID). SetUsedBy(userID).
......
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
//
// 性能优化说明:
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
// 1. 默认连接池大小可能不足以支撑高并发
// 2. 无超时控制可能导致慢操作阻塞
//
// 新实现支持可配置的连接池和超时参数:
// 1. PoolSize: 控制最大并发连接数(默认 128)
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10)
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(buildRedisOptions(cfg))
}
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
}
package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestBuildRedisOptions(t *testing.T) {
cfg := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "secret",
DB: 2,
DialTimeoutSeconds: 5,
ReadTimeoutSeconds: 3,
WriteTimeoutSeconds: 4,
PoolSize: 100,
MinIdleConns: 10,
},
}
opts := buildRedisOptions(cfg)
require.Equal(t, "localhost:6379", opts.Addr)
require.Equal(t, "secret", opts.Password)
require.Equal(t, 2, opts.DB)
require.Equal(t, 5*time.Second, opts.DialTimeout)
require.Equal(t, 3*time.Second, opts.ReadTimeout)
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
}
package repository
import (
"fmt"
"strings"
"sync"
"time"
"github.com/imroc/req/v3"
)
// reqClientOptions 定义 req 客户端的构建参数
type reqClientOptions struct {
ProxyURL string // 代理 URL(支持 http/https/socks5)
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
//
// 性能优化说明:
// 原实现在每次 OAuth 刷新时都创建新的 req.Client:
// 1. claude_oauth_service.go: 每次刷新创建新客户端
// 2. openai_oauth_service.go: 每次刷新创建新客户端
// 3. gemini_oauth_client.go: 每次刷新创建新客户端
//
// 新实现使用 sync.Map 缓存客户端:
// 1. 相同配置(代理+超时+模拟设置)复用同一客户端
// 2. 复用底层连接池,减少 TLS 握手开销
// 3. LoadOrStore 保证并发安全,避免重复创建
var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok {
return c
}
}
client := req.C().SetTimeout(opts.Timeout)
if opts.Impersonate {
client = client.ImpersonateChrome()
}
if strings.TrimSpace(opts.ProxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok {
return c
}
return client
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
)
}
...@@ -105,3 +105,59 @@ func (s *SettingRepoSuite) TestSetMultiple_Upsert() { ...@@ -105,3 +105,59 @@ func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal("new_val", got2) s.Require().Equal("new_val", got2)
} }
// TestSet_EmptyValue 测试保存空字符串值
// 这是一个回归测试,确保可选设置(如站点Logo、API端点地址等)可以保存为空字符串
func (s *SettingRepoSuite) TestSet_EmptyValue() {
// 测试 Set 方法保存空值
s.Require().NoError(s.repo.Set(s.ctx, "empty_key", ""), "Set with empty value should succeed")
got, err := s.repo.GetValue(s.ctx, "empty_key")
s.Require().NoError(err, "GetValue for empty value")
s.Require().Equal("", got, "empty value should be preserved")
}
// TestSetMultiple_WithEmptyValues 测试批量保存包含空字符串的设置
// 模拟用户保存站点设置时部分字段为空的场景
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
// 模拟保存站点设置,部分字段有值,部分字段为空
settings := map[string]string{
"site_name": "AICodex2API",
"site_subtitle": "Subscription to API",
"site_logo": "", // 用户未上传Logo
"api_base_url": "", // 用户未设置API地址
"contact_info": "", // 用户未设置联系方式
"doc_url": "", // 用户未设置文档链接
}
s.Require().NoError(s.repo.SetMultiple(s.ctx, settings), "SetMultiple with empty values should succeed")
// 验证所有值都正确保存
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
s.Require().Equal("AICodex2API", result["site_name"])
s.Require().Equal("Subscription to API", result["site_subtitle"])
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")
s.Require().Equal("", result["contact_info"], "empty contact_info should be preserved")
s.Require().Equal("", result["doc_url"], "empty doc_url should be preserved")
}
// TestSetMultiple_UpdateToEmpty 测试将已有值更新为空字符串
// 确保用户可以清空之前设置的值
func (s *SettingRepoSuite) TestSetMultiple_UpdateToEmpty() {
// 先设置非空值
s.Require().NoError(s.repo.Set(s.ctx, "clearable_key", "initial_value"))
got, err := s.repo.GetValue(s.ctx, "clearable_key")
s.Require().NoError(err)
s.Require().Equal("initial_value", got)
// 更新为空值
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"clearable_key": ""}), "Update to empty should succeed")
got, err = s.repo.GetValue(s.ctx, "clearable_key")
s.Require().NoError(err)
s.Require().Equal("", got, "value should be updated to empty string")
}
...@@ -7,10 +7,12 @@ import ( ...@@ -7,10 +7,12 @@ import (
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins" "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -111,3 +113,104 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { ...@@ -111,3 +113,104 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
Only(mixins.SkipSoftDelete(ctx)) Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted") require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
} }
// --- UserSubscription 软删除测试 ---
func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group {
t.Helper()
g, err := client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err, "create ent group")
return g
}
func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription")
_, err := repo.GetByID(ctx, sub.ID)
require.Error(t, err, "deleted rows should be hidden by default")
_, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(sub.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "first delete")
require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com")
g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a"))
g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b"))
repo := NewUserSubscriptionRepository(client)
sub1 := &service.UserSubscription{
UserID: u.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub1), "create subscription 1")
sub2 := &service.UserSubscription{
UserID: u.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub2), "create subscription 2")
// 软删除 sub1
require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1")
// ListByUserID 应只返回未删除的订阅
subs, err := repo.ListByUserID(ctx, u.ID)
require.NoError(t, err, "ListByUserID")
require.Len(t, subs, 1, "should only return non-deleted subscriptions")
require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned")
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -20,11 +21,15 @@ type turnstileVerifier struct { ...@@ -20,11 +21,15 @@ type turnstileVerifier struct {
} }
func NewTurnstileVerifier() service.TurnstileVerifier { func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}
}
return &turnstileVerifier{ return &turnstileVerifier{
httpClient: &http.Client{ httpClient: sharedClient,
Timeout: 10 * time.Second, verifyURL: turnstileVerifyURL,
},
verifyURL: turnstileVerifyURL,
} }
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"strings" "strings"
"time" "time"
...@@ -452,6 +453,176 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe ...@@ -452,6 +453,176 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
return &stats, nil return &stats, nil
} }
// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据
//
// 性能优化说明:
// 原实现先查询所有日志记录,再在应用层循环计算统计值:
// 1. 需要传输大量数据到应用层
// 2. 应用层循环计算增加 CPU 和内存开销
//
// 新实现使用 SQL 聚合函数:
// 1. 在数据库层完成 COUNT/SUM/AVG 计算
// 2. 只返回单行聚合结果,大幅减少数据传输量
// 3. 利用数据库索引优化聚合查询性能
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
// 性能优化:数据库层聚合计算,避免应用层循环统计
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE model = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{modelName, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
tzName := resolveUsageStatsTimezone()
query := `
SELECT
-- 使用应用时区分组,避免数据库会话时区导致日边界偏移。
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY 1
ORDER BY 1
`
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
result = nil
}
}()
result = make([]map[string]any, 0)
for rows.Next() {
var (
date string
totalRequests int64
totalInputTokens int64
totalOutputTokens int64
totalCacheTokens int64
totalCost float64
totalActualCost float64
avgDurationMs float64
)
if err = rows.Scan(
&date,
&totalRequests,
&totalInputTokens,
&totalOutputTokens,
&totalCacheTokens,
&totalCost,
&totalActualCost,
&avgDurationMs,
); err != nil {
return nil, err
}
result = append(result, map[string]any{
"date": date,
"total_requests": totalRequests,
"total_input_tokens": totalInputTokens,
"total_output_tokens": totalOutputTokens,
"total_cache_tokens": totalCacheTokens,
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
"total_cost": totalCost,
"total_actual_cost": totalActualCost,
"average_duration_ms": avgDurationMs,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。
// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。
func resolveUsageStatsTimezone() string {
tzName := timezone.Name()
if tzName != "" && tzName != "Local" {
return tzName
}
if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" {
return envTZ
}
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
...@@ -938,6 +1109,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -938,6 +1109,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
return nil, err return nil, err
} }
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today() today := timezone.Today()
todayQuery := ` todayQuery := `
...@@ -964,6 +1138,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -964,6 +1138,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
return nil, err return nil, err
} }
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil return result, nil
} }
...@@ -1006,6 +1183,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ...@@ -1006,6 +1183,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
return nil, err return nil, err
} }
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today() today := timezone.Today()
todayQuery := ` todayQuery := `
...@@ -1032,6 +1212,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ...@@ -1032,6 +1212,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
return nil, err return nil, err
} }
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil return result, nil
} }
......
...@@ -12,20 +12,18 @@ import ( ...@@ -12,20 +12,18 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
) )
type userRepository struct { type userRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor
} }
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository { func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB) return newUserRepositoryWithSQL(client, sqlDB)
} }
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository { func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
return &userRepository{client: client, sql: sqlq} return &userRepository{client: client}
} }
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error { func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
...@@ -86,10 +84,11 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, ...@@ -86,10 +84,11 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
out := userEntityToService(m) out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{id}) groups, err := r.loadAllowedGroups(ctx, []int64{id})
if err == nil { if err != nil {
if v, ok := groups[id]; ok { return nil, err
out.AllowedGroups = v }
} if v, ok := groups[id]; ok {
out.AllowedGroups = v
} }
return out, nil return out, nil
} }
...@@ -102,10 +101,11 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service ...@@ -102,10 +101,11 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service
out := userEntityToService(m) out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil { if err != nil {
if v, ok := groups[m.ID]; ok { return nil, err
out.AllowedGroups = v }
} if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
} }
return out, nil return out, nil
} }
...@@ -240,11 +240,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -240,11 +240,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
} }
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs) allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
if err == nil { if err != nil {
for id, u := range userMap { return nil, nil, err
if groups, ok := allowedGroupsByUser[id]; ok { }
u.AllowedGroups = groups for id, u := range userMap {
} if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
} }
} }
...@@ -252,12 +253,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -252,12 +253,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
} }
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) client := clientFromContext(ctx, r.client)
return err n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
} }
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
n, err := r.client.User.Update(). client := clientFromContext(ctx, r.client)
n, err := client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)). Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
AddBalance(-amount). AddBalance(-amount).
Save(ctx) Save(ctx)
...@@ -271,8 +280,15 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo ...@@ -271,8 +280,15 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
} }
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx) client := clientFromContext(ctx, r.client)
return err n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
} }
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
...@@ -280,33 +296,14 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, ...@@ -280,33 +296,14 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
} }
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
exec := r.sql // 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
if exec == nil { affected, err := r.client.UserAllowedGroup.Delete().
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。
exec = r.client
}
joinAffected, err := r.client.UserAllowedGroup.Delete().
Where(userallowedgroup.GroupIDEQ(groupID)). Where(userallowedgroup.GroupIDEQ(groupID)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return int64(affected), nil
arrayRes, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
groupID,
)
if err != nil {
return 0, err
}
arrayAffected, _ := arrayRes.RowsAffected()
if int64(joinAffected) > arrayAffected {
return int64(joinAffected), nil
}
return arrayAffected, nil
} }
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) { func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
...@@ -323,10 +320,11 @@ func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, erro ...@@ -323,10 +320,11 @@ func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, erro
out := userEntityToService(m) out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil { if err != nil {
if v, ok := groups[m.ID]; ok { return nil, err
out.AllowedGroups = v }
} if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
} }
return out, nil return out, nil
} }
...@@ -356,8 +354,7 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) ...@@ -356,8 +354,7 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
} }
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组: // syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致; // 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error { func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
if client == nil { if client == nil {
return nil return nil
...@@ -376,12 +373,10 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl ...@@ -376,12 +373,10 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
unique[id] = struct{}{} unique[id] = struct{}{}
} }
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 { if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique)) creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique { for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID)) creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
} }
if err := client.UserAllowedGroup. if err := client.UserAllowedGroup.
CreateBulk(creates...). CreateBulk(creates...).
...@@ -392,16 +387,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl ...@@ -392,16 +387,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
} }
} }
// Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil return nil
} }
......
...@@ -507,3 +507,24 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -507,3 +507,24 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch") s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
} }
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
}
...@@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio ...@@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error { func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil { if sub == nil {
return nil return service.ErrSubscriptionNilInput
} }
builder := r.client.UserSubscription.Create(). client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.Create().
SetUserID(sub.UserID). SetUserID(sub.UserID).
SetGroupID(sub.GroupID). SetGroupID(sub.GroupID).
SetExpiresAt(sub.ExpiresAt). SetExpiresAt(sub.ExpiresAt).
...@@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us ...@@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us
} }
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(id)). Where(usersubscription.IDEQ(id)).
WithUser(). WithUser().
WithGroup(). WithGroup().
...@@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se ...@@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se
} }
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)). Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
WithGroup(). WithGroup().
Only(ctx) Only(ctx)
...@@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, ...@@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context,
} }
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where( Where(
usersubscription.UserIDEQ(userID), usersubscription.UserIDEQ(userID),
usersubscription.GroupIDEQ(groupID), usersubscription.GroupIDEQ(groupID),
...@@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con ...@@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error { func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil { if sub == nil {
return nil return service.ErrSubscriptionNilInput
} }
builder := r.client.UserSubscription.UpdateOneID(sub.ID). client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.UpdateOneID(sub.ID).
SetUserID(sub.UserID). SetUserID(sub.UserID).
SetGroupID(sub.GroupID). SetGroupID(sub.GroupID).
SetStartsAt(sub.StartsAt). SetStartsAt(sub.StartsAt).
...@@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us ...@@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
// Match GORM semantics: deleting a missing row is not an error. // Match GORM semantics: deleting a missing row is not an error.
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx) client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
return err return err
} }
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID)). Where(usersubscription.UserIDEQ(userID)).
WithGroup(). WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)). Order(dbent.Desc(usersubscription.FieldCreatedAt)).
...@@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in ...@@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in
} }
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where( Where(
usersubscription.UserIDEQ(userID), usersubscription.UserIDEQ(userID),
usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.StatusEQ(service.SubscriptionStatusActive),
...@@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use ...@@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
} }
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)) client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
total, err := q.Clone().Count(ctx) total, err := q.Clone().Count(ctx)
if err != nil { if err != nil {
...@@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
} }
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query() client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil { if userID != nil {
q = q.Where(usersubscription.UserIDEQ(*userID)) q = q.Where(usersubscription.UserIDEQ(*userID))
} }
...@@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
} }
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
return client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)). Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
Exist(ctx) Exist(ctx)
} }
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID). client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetExpiresAt(newExpiresAt). SetExpiresAt(newExpiresAt).
Save(ctx) Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID). client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetStatus(status). SetStatus(status).
Save(ctx) Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID). client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetNotes(notes). SetNotes(notes).
Save(ctx) Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error { func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id). client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyWindowStart(start). SetDailyWindowStart(start).
SetWeeklyWindowStart(start). SetWeeklyWindowStart(start).
SetMonthlyWindowStart(start). SetMonthlyWindowStart(start).
...@@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int ...@@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int
} }
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id). client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyUsageUsd(0). SetDailyUsageUsd(0).
SetDailyWindowStart(newWindowStart). SetDailyWindowStart(newWindowStart).
Save(ctx) Save(ctx)
...@@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int ...@@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
} }
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id). client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetWeeklyUsageUsd(0). SetWeeklyUsageUsd(0).
SetWeeklyWindowStart(newWindowStart). SetWeeklyWindowStart(newWindowStart).
Save(ctx) Save(ctx)
...@@ -266,24 +283,54 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in ...@@ -266,24 +283,54 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
} }
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id). client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetMonthlyUsageUsd(0). SetMonthlyUsageUsd(0).
SetMonthlyWindowStart(newWindowStart). SetMonthlyWindowStart(newWindowStart).
Save(ctx) Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
// IncrementUsage 原子性地累加订阅用量。
// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
// 此处仅负责记录实际消费,确保消费数据的完整性。
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
_, err := r.client.UserSubscription.UpdateOneID(id). const updateSQL = `
AddDailyUsageUsd(costUSD). UPDATE user_subscriptions us
AddWeeklyUsageUsd(costUSD). SET
AddMonthlyUsageUsd(costUSD). daily_usage_usd = us.daily_usage_usd + $1,
Save(ctx) weekly_usage_usd = us.weekly_usage_usd + $1,
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, updateSQL, costUSD, id)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
// affected == 0:订阅不存在或已删除
return service.ErrSubscriptionNotFound
} }
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
n, err := r.client.UserSubscription.Update(). client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Update().
Where( Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()), usersubscription.ExpiresAtLTE(time.Now()),
...@@ -296,7 +343,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex ...@@ -296,7 +343,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
// Extra repository helpers (currently used only by integration tests). // Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) { func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where( Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()), usersubscription.ExpiresAtLTE(time.Now()),
...@@ -309,12 +357,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service ...@@ -309,12 +357,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service
} }
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx) client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err return int64(count), err
} }
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query(). client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().
Where( Where(
usersubscription.GroupIDEQ(groupID), usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.StatusEQ(service.SubscriptionStatusActive),
...@@ -325,7 +375,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g ...@@ -325,7 +375,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
} }
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx) client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
return int64(n), err return int64(n), err
} }
......
...@@ -4,6 +4,7 @@ package repository ...@@ -4,6 +4,7 @@ package repository
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"time" "time"
...@@ -631,3 +632,116 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba ...@@ -631,3 +632,116 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().NoError(err, "GetByID expired") s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
} }
// --- 软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
group := s.mustCreateGroup("g-softdeleted")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 软删除分组
_, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx)
s.Require().NoError(err, "soft delete group")
// IncrementUsage 应该失败,因为分组已软删除
err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0)
s.Require().Error(err, "should fail for soft-deleted group")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() {
err := s.repo.IncrementUsage(s.ctx, 999999, 1.0)
s.Require().Error(err, "should fail for non-existent subscription")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
// --- nil 入参测试 ---
func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() {
err := s.repo.Create(s.ctx, nil)
s.Require().Error(err, "Create should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
err := s.repo.Update(s.ctx, nil)
s.Require().Error(err, "Update should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
// --- 并发用量更新测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroup("g-concurrent")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10
const incrementPerGoroutine = 1.5
// 启动多个 goroutine 并发调用 IncrementUsage
errCh := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine)
}()
}
// 等待所有 goroutine 完成
for i := 0; i < numGoroutines; i++ {
err := <-errCh
s.Require().NoError(err, "IncrementUsage should succeed")
}
// 验证累加结果正确
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
expectedUsage := float64(numGoroutines) * incrementPerGoroutine
s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background())
s.Require().NoError(err, "begin tx")
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
txCtx := dbent.NewTxContext(context.Background(), tx)
suffix := fmt.Sprintf("%d", time.Now().UnixNano())
userEnt, err := tx.Client().User.Create().
SetEmail("tx-user-" + suffix + "@example.com").
SetPasswordHash("test").
Save(txCtx)
s.Require().NoError(err, "create user in tx")
groupEnt, err := tx.Client().Group.Create().
SetName("tx-group-" + suffix).
Save(txCtx)
s.Require().NoError(err, "create group in tx")
repo := NewUserSubscriptionRepository(baseClient)
sub := &service.UserSubscription{
UserID: userEnt.ID,
GroupID: groupEnt.ID,
ExpiresAt: time.Now().AddDate(0, 0, 30),
Status: service.SubscriptionStatusActive,
AssignedAt: time.Now(),
Notes: "tx",
}
s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx")
s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx")
s.Require().NoError(tx.Rollback(), "rollback tx")
tx = nil
_, err = repo.GetByID(context.Background(), sub.ID)
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
package repository package repository
import ( import (
"database/sql"
"errors"
entsql "entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9"
) )
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds())
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
}
if waitTTLSeconds <= 0 {
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
}
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
}
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
...@@ -20,7 +41,7 @@ var ProviderSet = wire.NewSet( ...@@ -20,7 +41,7 @@ var ProviderSet = wire.NewSet(
NewGatewayCache, NewGatewayCache,
NewBillingCache, NewBillingCache,
NewApiKeyCache, NewApiKeyCache,
NewConcurrencyCache, ProvideConcurrencyCache,
NewEmailCache, NewEmailCache,
NewIdentityCache, NewIdentityCache,
NewRedeemCache, NewRedeemCache,
...@@ -38,4 +59,58 @@ var ProviderSet = wire.NewSet( ...@@ -38,4 +59,58 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient, NewOpenAIOAuthClient,
NewGeminiOAuthClient, NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient, NewGeminiCliCodeAssistClient,
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
) )
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖:config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存(如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖:config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}
...@@ -385,7 +385,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -385,7 +385,7 @@ func newContractDeps(t *testing.T) *contractDeps {
authHandler := handler.NewAuthHandler(cfg, nil, userService) authHandler := handler.NewAuthHandler(cfg, nil, userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
...@@ -981,6 +981,18 @@ func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyI ...@@ -981,6 +981,18 @@ func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyI
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
"os" "os"
"strings" "strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
......
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
c.Next()
}
}
...@@ -18,8 +18,11 @@ func RegisterGatewayRoutes( ...@@ -18,8 +18,11 @@ func RegisterGatewayRoutes(
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
// API网关(Claude API兼容) // API网关(Claude API兼容)
gateway := r.Group("/v1") gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(gin.HandlerFunc(apiKeyAuth)) gateway.Use(gin.HandlerFunc(apiKeyAuth))
{ {
gateway.POST("/messages", h.Gateway.Messages) gateway.POST("/messages", h.Gateway.Messages)
...@@ -32,6 +35,7 @@ func RegisterGatewayRoutes( ...@@ -32,6 +35,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
...@@ -41,10 +45,11 @@ func RegisterGatewayRoutes( ...@@ -41,10 +45,11 @@ func RegisterGatewayRoutes(
} }
// OpenAI Responses API(不带v1前缀的别名) // OpenAI Responses API(不带v1前缀的别名)
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1") antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(bodyLimit)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
{ {
...@@ -55,6 +60,7 @@ func RegisterGatewayRoutes( ...@@ -55,6 +60,7 @@ func RegisterGatewayRoutes(
} }
antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment