Commit d8e40551 authored by yangjianbo's avatar yangjianbo
Browse files
parents 74d35f08 571d1479
package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type errorPassthroughRepository struct {
client *ent.Client
}
// NewErrorPassthroughRepository 创建错误透传规则仓库
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
return &errorPassthroughRepository{client: client}
}
// List 获取所有规则
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
rules, err := r.client.ErrorPassthroughRule.Query().
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]*model.ErrorPassthroughRule, len(rules))
for i, rule := range rules {
result[i] = r.toModel(rule)
}
return result, nil
}
// GetByID 根据 ID 获取规则
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return r.toModel(rule), nil
}
// Create 创建规则
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.Create().
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
}
created, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(created), nil
}
// Update 更新规则
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
// 处理可选字段
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
} else {
builder.ClearErrorCodes()
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
} else {
builder.ClearKeywords()
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
} else {
builder.ClearPlatforms()
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
} else {
builder.ClearResponseCode()
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
} else {
builder.ClearCustomMessage()
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
} else {
builder.ClearDescription()
}
updated, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(updated), nil
}
// Delete 删除规则
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
}
// toModel 将 Ent 实体转换为服务模型
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
rule := &model.ErrorPassthroughRule{
ID: int64(e.ID),
Name: e.Name,
Enabled: e.Enabled,
Priority: e.Priority,
ErrorCodes: e.ErrorCodes,
Keywords: e.Keywords,
MatchMode: e.MatchMode,
Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
if e.ResponseCode != nil {
rule.ResponseCode = e.ResponseCode
}
if e.CustomMessage != nil {
rule.CustomMessage = e.CustomMessage
}
if e.Description != nil {
rule.Description = e.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
return rule
}
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
...@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo ...@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String()) body := resp.String()
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body) sanitizedBody := geminicli.SanitizeBodyForLogs(body)
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body) fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
} }
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out) fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil return &out, nil
...@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken ...@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String()) body := resp.String()
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body) sanitizedBody := geminicli.SanitizeBodyForLogs(body)
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body) fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
} }
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out) fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil return &out, nil
......
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
refreshTokenKeyPrefix = "refresh_token:"
userRefreshTokensPrefix = "user_refresh_tokens:"
tokenFamilyPrefix = "token_family:"
)
// refreshTokenKey generates the Redis key for a refresh token.
func refreshTokenKey(tokenHash string) string {
return refreshTokenKeyPrefix + tokenHash
}
// userRefreshTokensKey generates the Redis key for user's token set.
func userRefreshTokensKey(userID int64) string {
return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID)
}
// tokenFamilyKey generates the Redis key for token family set.
func tokenFamilyKey(familyID string) string {
return tokenFamilyPrefix + familyID
}
type refreshTokenCache struct {
rdb *redis.Client
}
// NewRefreshTokenCache creates a new RefreshTokenCache implementation.
func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache {
return &refreshTokenCache{rdb: rdb}
}
func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error {
key := refreshTokenKey(tokenHash)
val, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal refresh token data: %w", err)
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) {
key := refreshTokenKey(tokenHash)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, service.ErrRefreshTokenNotFound
}
return nil, err
}
var data service.RefreshTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, fmt.Errorf("unmarshal refresh token data: %w", err)
}
return &data, nil
}
func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error {
key := refreshTokenKey(tokenHash)
return c.rdb.Del(ctx, key).Err()
}
func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error {
// Get all token hashes for this user
tokenHashes, err := c.GetUserTokenHashes(ctx, userID)
if err != nil && err != redis.Nil {
return fmt.Errorf("get user token hashes: %w", err)
}
if len(tokenHashes) == 0 {
return nil
}
// Build keys to delete
keys := make([]string, 0, len(tokenHashes)+1)
for _, hash := range tokenHashes {
keys = append(keys, refreshTokenKey(hash))
}
keys = append(keys, userRefreshTokensKey(userID))
// Delete all keys in a pipeline
pipe := c.rdb.Pipeline()
for _, key := range keys {
pipe.Del(ctx, key)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error {
// Get all token hashes in this family
tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID)
if err != nil && err != redis.Nil {
return fmt.Errorf("get family token hashes: %w", err)
}
if len(tokenHashes) == 0 {
return nil
}
// Build keys to delete
keys := make([]string, 0, len(tokenHashes)+1)
for _, hash := range tokenHashes {
keys = append(keys, refreshTokenKey(hash))
}
keys = append(keys, tokenFamilyKey(familyID))
// Delete all keys in a pipeline
pipe := c.rdb.Pipeline()
for _, key := range keys {
pipe.Del(ctx, key)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error {
key := userRefreshTokensKey(userID)
pipe := c.rdb.Pipeline()
pipe.SAdd(ctx, key, tokenHash)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error {
key := tokenFamilyKey(familyID)
pipe := c.rdb.Pipeline()
pipe.SAdd(ctx, key, tokenHash)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) {
key := userRefreshTokensKey(userID)
return c.rdb.SMembers(ctx, key).Result()
}
func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) {
key := tokenFamilyKey(familyID)
return c.rdb.SMembers(ctx, key).Result()
}
func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) {
key := tokenFamilyKey(familyID)
return c.rdb.SIsMember(ctx, key, tokenHash).Result()
}
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"strconv" "strconv"
"time" "time"
...@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv ...@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
if defaultIdleTimeoutMinutes <= 0 { if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟 defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
} }
// 预加载 Lua 脚本到 Redis,避免 Pipeline 中出现 NOSCRIPT 错误
ctx := context.Background()
scripts := []*redis.Script{
registerSessionScript,
refreshSessionScript,
getActiveSessionCountScript,
isSessionActiveScript,
}
for _, script := range scripts {
if err := script.Load(ctx, rdb).Err(); err != nil {
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
}
}
return &sessionLimitCache{ return &sessionLimitCache{
rdb: rdb, rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute, defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
......
...@@ -1128,6 +1128,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -1128,6 +1128,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
return stats, nil return stats, nil
} }
// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值)
func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
query := `
SELECT
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count
FROM usage_logs
WHERE created_at >= $1 AND api_key_id = $2`
args := []any{fiveMinutesAgo, apiKeyID}
var requestCount int64
var tokenCount int64
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
return 0, 0, err
}
return requestCount / 5, tokenCount / 5, nil
}
// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤)
func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) {
stats := &UserDashboardStats{}
today := timezone.Today()
// API Key 维度不需要统计 key 数量,设为 1
stats.TotalAPIKeys = 1
stats.ActiveAPIKeys = 1
// 累计 Token 统计
totalStatsQuery := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
WHERE api_key_id = $1
`
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{apiKeyID},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
// 今日 Token 统计
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE api_key_id = $1 AND created_at >= $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{apiKeyID, today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
// 性能指标:RPM 和 TPM(最近5分钟,按 API Key 过滤)
rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID)
if err != nil {
return nil, err
}
stats.Rpm = rpm
stats.Tpm = tpm
return stats, nil
}
// GetUserUsageTrendByUserID 获取指定用户的使用趋势 // GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
......
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户的所有专属分组倍率
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[int64]float64)
for rows.Next() {
var groupID int64
var rate float64
if err := rows.Scan(&groupID, &rate); err != nil {
return nil, err
}
result[groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &rate, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map,删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
if err != nil {
return err
}
}
return nil
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
...@@ -67,6 +67,8 @@ var ProviderSet = wire.NewSet( ...@@ -67,6 +67,8 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository, NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository, NewUserAttributeValueRepository,
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
...@@ -86,6 +88,8 @@ var ProviderSet = wire.NewSet( ...@@ -86,6 +88,8 @@ var ProviderSet = wire.NewSet(
NewSchedulerOutboxRepository, NewSchedulerOutboxRepository,
NewProxyLatencyCache, NewProxyLatencyCache,
NewTotpCache, NewTotpCache,
NewRefreshTokenCache,
NewErrorPassthroughCache,
// Encryptors // Encryptors
NewAESEncryptor, NewAESEncryptor,
......
...@@ -598,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -598,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
} }
userService := service.NewUserService(userRepo, nil) userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
...@@ -612,7 +612,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -612,7 +612,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
...@@ -1619,6 +1619,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int ...@@ -1619,6 +1619,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
...@@ -58,10 +58,39 @@ func ProvideRouter( ...@@ -58,10 +58,39 @@ func ProvideRouter(
// ProvideHTTPServer 提供 HTTP 服务器 // ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
handler := h2c.NewHandler(router, &http2.Server{}) httpHandler := http.Handler(router)
globalMaxSize := cfg.Server.MaxRequestBodySize
if globalMaxSize <= 0 {
globalMaxSize = cfg.Gateway.MaxBodySize
}
if globalMaxSize > 0 {
httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize)
log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20))
}
// 根据配置决定是否启用 H2C
if cfg.Server.H2C.Enabled {
h2cConfig := cfg.Server.H2C
httpHandler = h2c.NewHandler(router, &http2.Server{
MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams,
IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second,
MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize),
MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection),
MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream),
})
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
h2cConfig.MaxConcurrentStreams,
h2cConfig.IdleTimeout,
h2cConfig.MaxReadFrameSize,
h2cConfig.MaxUploadBufferPerConnection,
h2cConfig.MaxUploadBufferPerStream,
)
}
return &http.Server{ return &http.Server{
Addr: cfg.Server.Address(), Addr: cfg.Server.Address(),
Handler: handler, Handler: httpHandler,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
......
...@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService ...@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
nil, // userRepo (unused in GetByKey) nil, // userRepo (unused in GetByKey)
nil, // groupRepo nil, // groupRepo
nil, // userSubRepo nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache nil, // cache
&config.Config{}, &config.Config{},
) )
...@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) { ...@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
&config.Config{RunMode: config.RunModeSimple}, &config.Config{RunMode: config.RunModeSimple},
) )
......
...@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
...@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard} cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
now := time.Now() now := time.Now()
sub := &service.UserSubscription{ sub := &service.UserSubscription{
...@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) { ...@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
} }
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) { router.GET("/t", func(c *gin.Context) {
...@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { ...@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
} }
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
......
...@@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc { ...@@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc {
// 客户端IP // 客户端IP
clientIP := c.ClientIP() clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径 // 协议版本
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s", protocol := c.Request.Proto
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"), endTime.Format("2006/01/02 - 15:04:05"),
statusCode, statusCode,
latency, latency,
clientIP, clientIP,
protocol,
method, method,
path, path,
) )
......
...@@ -67,6 +67,9 @@ func RegisterAdminRoutes( ...@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
// 用户属性管理 // 用户属性管理
registerUserAttributeRoutes(admin, h) registerUserAttributeRoutes(admin, h)
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
} }
} }
...@@ -387,3 +390,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -387,3 +390,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
} }
} }
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{
rules.GET("", h.Admin.ErrorPassthrough.List)
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
rules.POST("", h.Admin.ErrorPassthrough.Create)
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
}
}
...@@ -28,6 +28,12 @@ func RegisterAuthRoutes( ...@@ -28,6 +28,12 @@ func RegisterAuthRoutes(
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA) auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.RefreshToken)
// 登出接口(公开,允许未认证用户调用以撤销Refresh Token)
auth.POST("/logout", h.Auth.Logout)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
...@@ -59,5 +65,7 @@ func RegisterAuthRoutes( ...@@ -59,5 +65,7 @@ func RegisterAuthRoutes(
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
{ {
authenticated.GET("/auth/me", h.Auth.GetCurrentUser) authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
} }
} }
...@@ -49,6 +49,7 @@ func RegisterUserRoutes( ...@@ -49,6 +49,7 @@ func RegisterUserRoutes(
groups := authenticated.Group("/groups") groups := authenticated.Group("/groups")
{ {
groups.GET("/available", h.APIKey.GetAvailableGroups) groups.GET("/available", h.APIKey.GetAvailableGroups)
groups.GET("/rates", h.APIKey.GetUserGroupRates)
} }
// 使用记录 // 使用记录
......
...@@ -41,6 +41,7 @@ type UsageLogRepository interface { ...@@ -41,6 +41,7 @@ type UsageLogRepository interface {
// User dashboard stats // User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error)
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
......
...@@ -93,6 +93,9 @@ type UpdateUserInput struct { ...@@ -93,6 +93,9 @@ type UpdateUserInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
} }
type CreateGroupInput struct { type CreateGroupInput struct {
...@@ -304,6 +307,7 @@ type adminServiceImpl struct { ...@@ -304,6 +307,7 @@ type adminServiceImpl struct {
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache proxyLatencyCache ProxyLatencyCache
...@@ -319,6 +323,7 @@ func NewAdminService( ...@@ -319,6 +323,7 @@ func NewAdminService(
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache, proxyLatencyCache ProxyLatencyCache,
...@@ -332,6 +337,7 @@ func NewAdminService( ...@@ -332,6 +337,7 @@ func NewAdminService(
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
proxyProber: proxyProber, proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache, proxyLatencyCache: proxyLatencyCache,
...@@ -346,11 +352,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi ...@@ -346,11 +352,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil return users, result.Total, nil
} }
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
}
return user, nil
} }
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
...@@ -419,6 +449,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -419,6 +449,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err return nil, err
} }
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
...@@ -974,6 +1012,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -974,6 +1012,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
if err != nil { if err != nil {
return err return err
} }
// 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
// 事务成功后,异步失效受影响用户的订阅缓存 // 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil { if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
......
...@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
...@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 处理错误响应 // 处理错误响应
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
contentType := resp.Header.Get("Content-Type")
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。 // 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
...@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
} }
contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
contentType = "application/json" contentType = "application/json"
} }
......
...@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct { ...@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
// APIKeyService API Key服务 // APIKeyService API Key服务
type APIKeyService struct { type APIKeyService struct {
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache APIKeyCache userGroupRateRepo UserGroupRateRepository
cfg *config.Config cache APIKeyCache
authCacheL1 *ristretto.Cache cfg *config.Config
authCfg apiKeyAuthCacheConfig authCacheL1 *ristretto.Cache
authGroup singleflight.Group authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
} }
// NewAPIKeyService 创建API Key服务实例 // NewAPIKeyService 创建API Key服务实例
...@@ -132,16 +133,18 @@ func NewAPIKeyService( ...@@ -132,16 +133,18 @@ func NewAPIKeyService(
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache, cache APIKeyCache,
cfg *config.Config, cfg *config.Config,
) *APIKeyService { ) *APIKeyService {
svc := &APIKeyService{ svc := &APIKeyService{
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
cache: cache, userGroupRateRepo: userGroupRateRepo,
cfg: cfg, cache: cache,
cfg: cfg,
} }
svc.initAuthCache(cfg) svc.initAuthCache(cfg)
return svc return svc
...@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword ...@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
return keys, nil return keys, nil
} }
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) // CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid // Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
......
...@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9) groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{ cacheEntry := &APIKeyAuthCacheEntry{
...@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { ...@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil return &APIKeyAuthCacheEntry{NotFound: true}, nil
} }
...@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { ...@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil return nil, redis.Nil
} }
...@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { ...@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
L1TTLSeconds: 60, L1TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1) require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1") _, err := svc.GetByKey(context.Background(), "k-l1")
...@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { ...@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7) svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2) require.Len(t, cache.deleteAuthKeys, 2)
...@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { ...@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
L2TTLSeconds: 60, L2TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9) svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2) require.Len(t, cache.deleteAuthKeys, 2)
...@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { ...@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
L2TTLSeconds: 60, L2TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1") svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1) require.Len(t, cache.deleteAuthKeys, 1)
...@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { ...@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil return nil, redis.Nil
} }
...@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { ...@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
Singleflight: true, Singleflight: true,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{}) start := make(chan struct{})
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment