Commit 5e060b22 authored by erio's avatar erio
Browse files

Merge remote-tracking branch 'upstream/main' into feat/channel-insights

# Conflicts:
#	backend/cmd/server/wire_gen.go
parents 6f04c25e 0a80ec80
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { ...@@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
require.ErrorContains(s.T(), err, "request failed") require.ErrorContains(s.T(), err, "request failed")
} }
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err))
require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured")
}
func (s *OpenAIOAuthServiceSuite) TestContextCancel() { func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{}) started := make(chan struct{})
block := make(chan struct{}) block := make(chan struct{})
......
...@@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI ...@@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }()
var state service.AccountQuotaState var state service.AccountQuotaState
if rows.Next() { if rows.Next() {
...@@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI ...@@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
&state.DailyUsed, &state.DailyLimit, &state.DailyUsed, &state.DailyLimit,
&state.WeeklyUsed, &state.WeeklyLimit, &state.WeeklyUsed, &state.WeeklyLimit,
); err != nil { ); err != nil {
_ = rows.Close()
return nil, err return nil, err
} }
} else { } else {
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
_ = rows.Close()
return nil, err return nil, err
} }
_ = rows.Close()
return nil, service.ErrAccountNotFound return nil, service.ErrAccountNotFound
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
_ = rows.Close()
return nil, err return nil, err
} }
if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit { // 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上
// 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
// "unexpected Parse response" 错误。
if err := rows.Close(); err != nil {
return nil, err
}
// 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
// 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
// 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
// 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount),
// 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit
crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit
crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit
if crossedTotal || crossedDaily || crossedWeekly {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return nil, err return nil, err
......
...@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) { ...@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
require.InDelta(t, 3.5, quotaUsed, 0.000001) require.InDelta(t, 3.5, quotaUsed, 0.000001)
} }
func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
newFixture := func(t *testing.T, extra map[string]any) (int64, int64) {
t.Helper()
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()),
PasswordHash: "hash",
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-outbox-" + uuid.NewString(),
Name: "billing-outbox",
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-outbox-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
Extra: extra,
})
return apiKey.ID, account.ID
}
outboxCountFor := func(t *testing.T, accountID int64) int {
t.Helper()
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx,
"SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2",
service.SchedulerOutboxEventAccountChanged, accountID,
).Scan(&count))
return count
}
t.Run("daily_first_crossing_enqueues", func(t *testing.T) {
apiKeyID, accountID := newFixture(t, map[string]any{
"quota_daily_limit": 10.0,
})
// 第一次低于日限额:不应入队 outbox
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 4,
})
require.NoError(t, err)
require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue")
// 第二次跨越日限额:应入队一次 outbox
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 8,
})
require.NoError(t, err)
require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once")
// 再次递增(已超):不应重复入队
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 2,
})
require.NoError(t, err)
require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue")
})
t.Run("weekly_first_crossing_enqueues", func(t *testing.T) {
apiKeyID, accountID := newFixture(t, map[string]any{
"quota_weekly_limit": 10.0,
})
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKeyID,
AccountID: accountID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 15, // 单次即跨越
})
require.NoError(t, err)
require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once")
})
}
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) { func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background() ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB) repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
......
...@@ -13,14 +13,14 @@ type userGroupRateRepository struct { ...@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql sqlExecutor sql sqlExecutor
} }
// NewUserGroupRateRepository 创建用户专属分组倍率仓储 // NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB} return &userGroupRateRepository{sql: sqlDB}
} }
// GetByUserID 获取用户所有专属分组倍率 // GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1` query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
rows, err := r.sql.QueryContext(ctx, query, userID) rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) ...@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil return result, nil
} }
// GetByUserIDs 批量获取多个用户的专属分组倍率。 // GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
// 返回结构:map[userID]map[groupID]rate
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs)) result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 { if len(userIDs) == 0 {
...@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in ...@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows, err := r.sql.QueryContext(ctx, ` rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers FROM user_group_rate_multipliers
WHERE user_id = ANY($1) WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
`, pq.Array(uniqueIDs)) `, pq.Array(uniqueIDs))
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in ...@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return result, nil return result, nil
} }
// GetByGroupID 获取指定分组下所有用户的专属倍率 // GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) { func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := ` query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
FROM user_group_rate_multipliers ugr FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1 WHERE ugr.group_id = $1
...@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 ...@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry var result []service.UserGroupRateEntry
for rows.Next() { for rows.Next() {
var entry service.UserGroupRateEntry var entry service.UserGroupRateEntry
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil { var rate sql.NullFloat64
var rpm sql.NullInt32
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
return nil, err return nil, err
} }
if rate.Valid {
v := rate.Float64
entry.RateMultiplier = &v
}
if rpm.Valid {
v := int(rpm.Int32)
entry.RPMOverride = &v
}
result = append(result, entry) result = append(result, entry)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
...@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 ...@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return result, nil return result, nil
} }
// GetByUserAndGroup 获取用户在特定分组的专属倍率 // GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64 var rate sql.NullFloat64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
...@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, ...@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &rate, nil if !rate.Valid {
return nil, nil
}
v := rate.Float64
return &v, nil
}
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rpm sql.NullInt32
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
if !rpm.Valid {
return nil, nil
}
v := int(rpm.Int32)
return &v, nil
} }
// SyncUserGroupRates 同步用户的分组专属倍率 // SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 { if len(rates) == 0 {
// 如果传入空 map,删除该用户的所有专属倍率 if _, err := r.sql.ExecContext(ctx, `
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1
`, userID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID)
return err return err
} }
// 分离需要删除和需要 upsert 的记录 var clearGroupIDs []int64
var toDelete []int64
upsertGroupIDs := make([]int64, 0, len(rates)) upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates)) upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates { for groupID, rate := range rates {
if rate == nil { if rate == nil {
toDelete = append(toDelete, groupID) clearGroupIDs = append(clearGroupIDs, groupID)
} else { } else {
upsertGroupIDs = append(upsertGroupIDs, groupID) upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate) upsertRates = append(upsertRates, *rate)
} }
} }
// 删除指定的记录 if len(clearGroupIDs) > 0 {
if len(toDelete) > 0 { if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1 AND group_id = ANY($2)
`, userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, if _, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID, pq.Array(toDelete)); err != nil { userID, pq.Array(clearGroupIDs)); err != nil {
return err return err
} }
} }
// Upsert 记录
now := time.Now()
if len(upsertGroupIDs) > 0 { if len(upsertGroupIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, ` _, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT SELECT
...@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID ...@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil return nil
} }
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插) // SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
// 语义:
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
// - 出现的用户行:upsert rate_multiplier。
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error { func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil { keepUserIDs := make([]int64, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
}
// 未在 entries 列表中的行:清空 rate_multiplier。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err return err
} }
if len(entries) == 0 { if len(entries) == 0 {
return nil return nil
} }
userIDs := make([]int64, len(entries)) userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries)) rates := make([]float64, len(entries))
for i, e := range entries { for i, e := range entries {
...@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, ...@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return err return err
} }
// DeleteByGroupID 删除指定分组的所有用户专属倍率 // SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
// 语义:
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
keepUserIDs := make([]int64, 0, len(entries))
var clearUserIDs []int64
upsertUserIDs := make([]int64, 0, len(entries))
upsertValues := make([]int32, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
if e.RPMOverride == nil {
clearUserIDs = append(clearUserIDs, e.UserID)
} else {
upsertUserIDs = append(upsertUserIDs, e.UserID)
upsertValues = append(upsertValues, int32(*e.RPMOverride))
}
}
// 未在 entries 列表中的行:清空 rpm_override。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 显式 clear 的行。
if len(clearUserIDs) > 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id = ANY($2)
`, groupID, pq.Array(clearUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err
}
if len(upsertUserIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
if err != nil {
return err
}
}
return nil
}
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID)
return err
}
// DeleteByGroupID 删除指定分组的所有用户专属条目
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err return err
} }
// DeleteByUserID 删除指定用户的所有专属倍率 // DeleteByUserID 删除指定用户的所有专属条目
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err return err
......
...@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt). SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt). SetNillableLastActiveAt(userIn.LastActiveAt).
SetRpmLimit(userIn.RPMLimit).
Save(txCtx) Save(txCtx)
if err != nil { if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
...@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged) SetTotalRecharged(userIn.TotalRecharged).
SetRpmLimit(userIn.RPMLimit)
if userIn.SignupSource != "" { if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource) updateOp = updateOp.SetSignupSource(userIn.SignupSource)
} }
......
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 用户/分组级 RPM 计数器 Redis 实现。
//
// 设计说明:
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
const (
userGroupRPMKeyPrefix = "rpm:ug:"
userRPMKeyPrefix = "rpm:u:"
userRPMKeyTTL = 120 * time.Second
)
type userRPMCacheImpl struct {
rdb *redis.Client
}
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
return &userRPMCacheImpl{rdb: rdb}
}
// minuteTS 获取当前 Redis 服务端分钟时间戳。
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
t, err := c.rdb.Time(ctx).Result()
if err != nil {
return 0, fmt.Errorf("redis TIME: %w", err)
}
return t.Unix() / 60, nil
}
// atomicIncr 原子 INCR+EXPIRE。
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
pipe := c.rdb.TxPipeline()
incr := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, userRPMKeyTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("user rpm increment: %w", err)
}
return int(incr.Val()), nil
}
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
return c.atomicIncr(ctx, key)
}
// IncrementUserRPM 递增用户分钟计数。
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
return c.atomicIncr(ctx, key)
}
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user group rpm get: %w", err)
}
return val, nil
}
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user rpm get: %w", err)
}
return val, nil
}
...@@ -98,10 +98,12 @@ var ProviderSet = wire.NewSet( ...@@ -98,10 +98,12 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache, NewAPIKeyCache,
NewTempUnschedCache, NewTempUnschedCache,
NewTimeoutCounterCache, NewTimeoutCounterCache,
NewOpenAI403CounterCache,
NewInternal500CounterCache, NewInternal500CounterCache,
ProvideConcurrencyCache, ProvideConcurrencyCache,
ProvideSessionLimitCache, ProvideSessionLimitCache,
NewRPMCache, NewRPMCache,
NewUserRPMCache,
NewUserMsgQueueCache, NewUserMsgQueueCache,
NewDashboardCache, NewDashboardCache,
NewEmailCache, NewEmailCache,
......
...@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) {
"role": "user", "role": "user",
"balance": 12.5, "balance": 12.5,
"concurrency": 5, "concurrency": 5,
"rpm_limit": 0,
"status": "active", "status": "active",
"allowed_groups": null, "allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
...@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_group_id_on_invalid_request": null, "fallback_group_id_on_invalid_request": null,
"require_oauth_only": false, "require_oauth_only": false,
"require_privacy_set": false, "require_privacy_set": false,
"rpm_limit": 0,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
...@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false, "force_email_on_third_party_signup": false,
"default_concurrency": 5, "default_concurrency": 5,
"default_balance": 1.25, "default_balance": 1.25,
"default_user_rpm_limit": 0,
"default_subscriptions": [], "default_subscriptions": [],
"enable_model_fallback": false, "enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...@@ -892,6 +895,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -892,6 +895,7 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [], "custom_endpoints": [],
"default_concurrency": 0, "default_concurrency": 0,
"default_balance": 0, "default_balance": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [], "default_subscriptions": [],
"enable_model_fallback": false, "enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...@@ -1090,7 +1094,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -1090,7 +1094,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
......
...@@ -224,6 +224,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -224,6 +224,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/usage", h.Admin.User.GetUserUsage)
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
// User attribute values // User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
...@@ -247,6 +248,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -247,6 +248,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers) groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers) groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers) groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides)
groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
} }
} }
......
...@@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit ...@@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit
return false return false
} }
switch capability { switch capability {
case OpenAIImagesCapabilityBasic: case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative:
return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
case OpenAIImagesCapabilityNative:
return a.Type == AccountTypeAPIKey
default: default:
return true return true
} }
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
...@@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C ...@@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
return nil return nil
} }
// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API. // testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API.
func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
authToken := account.GetOpenAIAccessToken() authToken := account.GetOpenAIAccessToken()
if authToken == "" { if authToken == "" {
...@@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co ...@@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
c.Writer.Flush() c.Writer.Flush()
s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"}) s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"})
// Build headers (replicating buildOpenAIBackendAPIHeaders logic) parsed := &OpenAIImagesRequest{
headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo) Endpoint: openAIImagesGenerationsEndpoint,
Model: strings.TrimSpace(modelID),
proxyURL := "" Prompt: prompt,
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
} }
applyOpenAIImagesDefaults(parsed)
client, err := newOpenAIBackendAPIClient(proxyURL) responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error()))
} }
// Bootstrap req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody))
if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr)
}
// Fetch chat requirements
s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"})
chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error())) return s.sendErrorAndEnd(c, "Failed to create request")
} }
if chatReqs.Arkose.Required { req.Host = "chatgpt.com"
return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required") req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("originator", "opencode")
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
req.Header.Set("User-Agent", customUA)
} else {
req.Header.Set("User-Agent", codexCLIUserAgent)
} }
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
// Initialize and prepare conversation req.Header.Set("chatgpt-account-id", chatgptAccountID)
s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"})
parentMessageID := uuid.NewString()
proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
_ = initializeOpenAIImageConversation(ctx, client, headers)
conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error()))
} }
// Build simplified conversation request (no file uploads) proxyURL := ""
convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID) if account.ProxyID != nil && account.Proxy != nil {
convHeaders := cloneHTTPHeader(headers) proxyURL = account.Proxy.URL()
convHeaders.Set("Accept", "text/event-stream")
convHeaders.Set("Content-Type", "application/json")
convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
if conduitToken != "" {
convHeaders.Set("x-conduit-token", conduitToken)
}
if proofToken != "" {
convHeaders.Set("openai-sentinel-proof-token", proofToken)
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"})
resp, err := client.R().
SetContext(ctx).
DisableAutoReadResponse().
SetHeaders(headerToMap(convHeaders)).
SetBodyJsonMarshal(convReq).
Post(openAIChatGPTConversationURL)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error()))
} }
defer func() { defer func() {
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
...@@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co ...@@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
} }
}() }()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
message := strings.TrimSpace(extractUpstreamErrorMessage(body))
if message == "" {
message = fmt.Sprintf("Responses API returned %d", resp.StatusCode)
}
return s.sendErrorAndEnd(c, message)
} }
startTime := time.Now() body, err := io.ReadAll(resp.Body)
conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error()))
} }
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body)
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { if err != nil {
s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"}) return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error()))
polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
if pollErr != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error()))
}
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
} }
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) if len(results) == 0 {
if len(pointerInfos) == 0 { return s.sendErrorAndEnd(c, "No images returned from responses API")
return s.sendErrorAndEnd(c, "No images returned from conversation")
} }
s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"}) for _, item := range results {
if item.RevisedPrompt != "" {
// Download and encode each image s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
for _, pointer := range pointerInfos {
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error()))
}
data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error()))
}
b64 := base64.StdEncoding.EncodeToString(data)
mimeType := http.DetectContentType(data)
if pointer.Prompt != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt})
} }
mimeType := openAIImageOutputMIMEType(item.OutputFormat)
s.sendEvent(c, TestEvent{ s.sendEvent(c, TestEvent{
Type: "image", Type: "image",
ImageURL: "data:" + mimeType + ";base64," + b64, ImageURL: "data:" + mimeType + ";base64," + item.Result,
MimeType: mimeType, MimeType: mimeType,
}) })
} }
...@@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co ...@@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
return nil return nil
} }
// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes.
// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without
// requiring the full gateway service dependency.
func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header {
// Ensure device and session IDs exist
deviceID := account.GetOpenAIDeviceID()
sessionID := account.GetOpenAISessionID()
if deviceID == "" || sessionID == "" {
updates := map[string]any{}
if deviceID == "" {
deviceID = uuid.NewString()
updates["openai_device_id"] = deviceID
}
if sessionID == "" {
sessionID = uuid.NewString()
updates["openai_session_id"] = sessionID
}
if account.Extra == nil {
account.Extra = map[string]any{}
}
for key, value := range updates {
account.Extra[key] = value
}
if repo != nil {
updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
_ = repo.UpdateExtra(updateCtx, account.ID, updates)
}
}
headers := make(http.Header)
headers.Set("Authorization", "Bearer "+token)
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://chatgpt.com")
headers.Set("Referer", "https://chatgpt.com/")
headers.Set("Sec-Fetch-Dest", "empty")
headers.Set("Sec-Fetch-Mode", "cors")
headers.Set("Sec-Fetch-Site", "same-origin")
headers.Set("User-Agent", openAIImageBackendUserAgent)
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
headers.Set("User-Agent", customUA)
}
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
headers.Set("chatgpt-account-id", chatgptAccountID)
}
if deviceID != "" {
headers.Set("oai-device-id", deviceID)
headers.Set("Cookie", "oai-did="+deviceID)
}
if sessionID != "" {
headers.Set("oai-session-id", sessionID)
}
return headers
}
// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request.
func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any {
promptText := strings.TrimSpace(prompt)
if promptText == "" {
promptText = "Generate an image."
}
metadata := map[string]any{
"developer_mode_connector_ids": []any{},
"selected_github_repos": []any{},
"selected_all_github_repos": false,
"system_hints": []string{"picture_v2"},
"serialization_metadata": map[string]any{
"custom_symbol_offsets": []any{},
},
}
message := map[string]any{
"id": uuid.NewString(),
"author": map[string]any{"role": "user"},
"content": map[string]any{
"content_type": "text",
"parts": []any{promptText},
},
"metadata": metadata,
"create_time": float64(time.Now().UnixMilli()) / 1000,
}
return map[string]any{
"action": "next",
"client_prepare_state": "sent",
"parent_message_id": parentMessageID,
"messages": []any{message},
"model": "auto",
"timezone_offset_min": openAITimezoneOffsetMinutes(),
"timezone": openAITimezoneName(),
"conversation_mode": map[string]any{"kind": "primary_assistant"},
"system_hints": []string{"picture_v2"},
"supports_buffering": true,
"supported_encodings": []string{"v1"},
"client_contextual_info": map[string]any{"app_name": "chatgpt.com"},
"force_nulligen": false,
"force_paragen": false,
"force_paragen_model_slug": "",
"force_rate_limit": false,
"websocket_request_id": uuid.NewString(),
}
}
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event) eventJSON, _ := json.Marshal(event)
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
......
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
"data: [DONE]\n\n",
)),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 53,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token-123",
},
}
err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat")
require.NoError(t, err)
require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool")
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
require.Contains(t, rec.Body.String(), "\"success\":true")
}
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"sort"
"strings" "strings"
"time" "time"
...@@ -32,6 +33,7 @@ type AdminService interface { ...@@ -32,6 +33,7 @@ type AdminService interface {
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types. // codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups). // Also returns totalRecharged (sum of all positive balance top-ups).
...@@ -50,6 +52,8 @@ type AdminService interface { ...@@ -50,6 +52,8 @@ type AdminService interface {
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin) // API Key management (admin)
...@@ -114,6 +118,7 @@ type CreateUserInput struct { ...@@ -114,6 +118,7 @@ type CreateUserInput struct {
Notes string Notes string
Balance float64 Balance float64
Concurrency int Concurrency int
RPMLimit int
AllowedGroups []int64 AllowedGroups []int64
} }
...@@ -124,6 +129,7 @@ type UpdateUserInput struct { ...@@ -124,6 +129,7 @@ type UpdateUserInput struct {
Notes *string Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0" Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
RPMLimit *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置 // GroupRates 用户专属分组倍率配置
...@@ -199,6 +205,8 @@ type CreateGroupInput struct { ...@@ -199,6 +205,8 @@ type CreateGroupInput struct {
RequireOAuthOnly bool RequireOAuthOnly bool
RequirePrivacySet bool RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制)
RPMLimit int
// 从指定分组复制账号(创建分组后在同一事务内绑定) // 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -234,6 +242,8 @@ type UpdateGroupInput struct { ...@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
RequireOAuthOnly *bool RequireOAuthOnly *bool
RequirePrivacySet *bool RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
RPMLimit *int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct { ...@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys int64 // 迁移的 Key 数量 MigratedKeys int64 // 迁移的 Key 数量
} }
// UserRPMStatus describes a user's current per-minute RPM usage.
type UserRPMStatus struct {
UserRPMUsed int `json:"user_rpm_used"`
UserRPMLimit int `json:"user_rpm_limit"`
PerGroup []UserGroupRPMStatus `json:"per_group"`
}
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
type UserGroupRPMStatus struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Used int `json:"used"`
Limit int `json:"limit"`
Source string `json:"source"` // "group" | "override"
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates. // BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct { type BulkUpdateAccountsResult struct {
Success int `json:"success"` Success int `json:"success"`
...@@ -463,6 +489,8 @@ const ( ...@@ -463,6 +489,8 @@ const (
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
) )
var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
// adminServiceImpl implements AdminService // adminServiceImpl implements AdminService
type adminServiceImpl struct { type adminServiceImpl struct {
userRepo UserRepository userRepo UserRepository
...@@ -472,6 +500,7 @@ type adminServiceImpl struct { ...@@ -472,6 +500,7 @@ type adminServiceImpl struct {
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository userGroupRateRepo UserGroupRateRepository
userRPMCache UserRPMCache
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache proxyLatencyCache ProxyLatencyCache
...@@ -496,6 +525,7 @@ func NewAdminService( ...@@ -496,6 +525,7 @@ func NewAdminService(
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository, userGroupRateRepo UserGroupRateRepository,
userRPMCache UserRPMCache,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache, proxyLatencyCache ProxyLatencyCache,
...@@ -514,6 +544,7 @@ func NewAdminService( ...@@ -514,6 +544,7 @@ func NewAdminService(
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo, userGroupRateRepo: userGroupRateRepo,
userRPMCache: userRPMCache,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
proxyProber: proxyProber, proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache, proxyLatencyCache: proxyLatencyCache,
...@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu ...@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Role: RoleUser, // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
RPMLimit: input.RPMLimit,
Status: StatusActive, Status: StatusActive,
AllowedGroups: input.AllowedGroups, AllowedGroups: input.AllowedGroups,
} }
...@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency oldConcurrency := user.Concurrency
oldStatus := user.Status oldStatus := user.Status
oldRole := user.Role oldRole := user.Role
oldRPMLimit := user.RPMLimit
if input.Email != "" { if input.Email != "" {
user.Email = input.Email user.Email = input.Email
...@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.Concurrency = *input.Concurrency user.Concurrency = *input.Concurrency
} }
if input.RPMLimit != nil {
user.RPMLimit = *input.RPMLimit
}
if input.AllowedGroups != nil { if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups user.AllowedGroups = *input.AllowedGroups
} }
...@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
} }
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
} }
} }
...@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag ...@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil return keys, result.Total, nil
} }
func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
if s.userRPMCache == nil {
return nil, ErrRPMStatusUnavailable
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
}
keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
if err != nil {
return nil, err
}
groupIDSet := make(map[int64]struct{})
for _, key := range keys {
if key.GroupID != nil && *key.GroupID > 0 {
groupIDSet[*key.GroupID] = struct{}{}
}
}
groupIDs := make([]int64, 0, len(groupIDSet))
for groupID := range groupIDSet {
groupIDs = append(groupIDs, groupID)
}
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
var perGroup []UserGroupRPMStatus
for _, groupID := range groupIDs {
used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
if getErr != nil {
logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
}
entry := UserGroupRPMStatus{
GroupID: groupID,
Used: used,
}
if s.groupRepo != nil {
if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
entry.GroupName = group.Name
entry.Limit = group.RPMLimit
entry.Source = "group"
} else if groupErr != nil {
logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
}
}
if s.userGroupRateRepo != nil {
override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
if overrideErr != nil {
logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
} else if override != nil {
entry.Limit = *override
entry.Source = "override"
}
}
perGroup = append(perGroup, entry)
}
return &UserRPMStatus{
UserRPMUsed: userRPMUsed,
UserRPMLimit: user.RPMLimit,
PerGroup: perGroup,
}, nil
}
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now // Return mock data for now
return map[string]any{ return map[string]any{
...@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequirePrivacySet: input.RequirePrivacySet, RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel, DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
RPMLimit: input.RPMLimit,
} }
sanitizeGroupMessagesDispatchFields(group) sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
...@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MessagesDispatchModelConfig != nil { if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
} }
if input.RPMLimit != nil {
group.RPMLimit = *input.RPMLimit
}
sanitizeGroupMessagesDispatchFields(group) sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
} }
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 { if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs // 去重源分组 IDs
...@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
} }
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil return group, nil
} }
...@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro ...@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
} }
func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
if s.userGroupRateRepo == nil {
return nil
}
if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
return err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
}
return nil
}
func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
if s.userGroupRateRepo == nil {
return nil
}
for _, e := range entries {
if e.RPMOverride != nil && *e.RPMOverride < 0 {
return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
}
}
if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
return err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
}
return nil
}
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates) return s.groupRepo.UpdateSortOrders(ctx, updates)
} }
......
...@@ -5,8 +5,10 @@ package service ...@@ -5,8 +5,10 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"net/http"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct { ...@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID int64 syncedGroupID int64
syncedEntries []GroupRateMultiplierInput syncedEntries []GroupRateMultiplierInput
syncGroupErr error syncGroupErr error
rpmSyncedGroupID int64
rpmSyncedEntries []GroupRPMOverrideInput
rpmSyncErr error
} }
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) { func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
...@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, ...@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call") panic("unexpected GetByUserAndGroup call")
} }
func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) { func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil { if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr return nil, s.getByGroupIDErr
...@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C ...@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return s.syncGroupErr return s.syncGroupErr
} }
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
s.rpmSyncedGroupID = groupID
s.rpmSyncedEntries = entries
return s.rpmSyncErr
}
func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error { func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID) s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr return s.deleteByGroupErr
...@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { ...@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{ repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{ getByGroupIDData: map[int64][]UserGroupRateEntry{
10: { 10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5}, {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8}, {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
}, },
}, },
} }
...@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { ...@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require.Len(t, entries, 2) require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID) require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName) require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier) require.NotNil(t, entries[0].RateMultiplier)
require.Equal(t, 1.5, *entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID) require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier) require.NotNil(t, entries[1].RateMultiplier)
require.Equal(t, 0.8, *entries[1].RateMultiplier)
}) })
t.Run("returns nil when repo is nil", func(t *testing.T) { t.Run("returns nil when repo is nil", func(t *testing.T) {
...@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) { ...@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require.Contains(t, err.Error(), "sync failed") require.Contains(t, err.Error(), "sync failed")
}) })
} }
func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
t.Run("syncs entries to repo", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
override := 20
entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
require.NoError(t, err)
require.Equal(t, int64(10), repo.rpmSyncedGroupID)
require.Equal(t, entries, repo.rpmSyncedEntries)
})
t.Run("rejects negative override as bad request", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
negative := -1
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
{UserID: 2, RPMOverride: &negative},
})
require.Error(t, err)
require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
require.Zero(t, repo.rpmSyncedGroupID)
})
}
...@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAnthropic,
Status: StatusActive,
RPMLimit: 10,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
groupRepo: repo,
authCacheInvalidator: invalidator,
}
rpmLimit := 60
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
RPMLimit: &rpmLimit,
})
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, 60, repo.updated.RPMLimit)
require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
}
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
......
...@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, ...@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call") panic("unexpected GetByUserAndGroup call")
} }
func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call") panic("unexpected SyncUserGroupRates call")
} }
...@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C ...@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic("unexpected SyncGroupRateMultipliers call") panic("unexpected SyncGroupRateMultipliers call")
} }
func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
panic("unexpected SyncGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error { func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call") panic("unexpected DeleteByGroupID call")
} }
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type rpmStatusUserRepoStub struct {
UserRepository
user *User
}
func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
return s.user, nil
}
type rpmStatusAPIKeyRepoStub struct {
APIKeyRepository
keys []APIKey
}
func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
}
type rpmStatusGroupRepoStub struct {
GroupRepository
groups map[int64]*Group
}
func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
return s.groups[id], nil
}
type rpmStatusRateRepoStub struct {
UserGroupRateRepository
overrides map[int64]*int
}
func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
return s.overrides[groupID], nil
}
type rpmStatusCacheStub struct {
UserRPMCache
userUsed int
groupUsed map[int64]int
}
func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
return 0, nil
}
func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
return 0, nil
}
func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
return s.groupUsed[groupID], nil
}
func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
return s.userUsed, nil
}
func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
groupOneID := int64(1)
groupTwoID := int64(2)
override := 7
svc := &adminServiceImpl{
userRepo: &rpmStatusUserRepoStub{user: &User{
ID: 42,
RPMLimit: 20,
}},
apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
{ID: 100, UserID: 42, GroupID: &groupTwoID},
{ID: 101, UserID: 42, GroupID: &groupOneID},
{ID: 102, UserID: 42, GroupID: &groupTwoID},
{ID: 103, UserID: 42},
}},
groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
}},
userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
groupTwoID: &override,
}},
userRPMCache: &rpmStatusCacheStub{
userUsed: 5,
groupUsed: map[int64]int{
groupOneID: 3,
groupTwoID: 4,
},
},
}
status, err := svc.GetUserRPMStatus(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, &UserRPMStatus{
UserRPMUsed: 5,
UserRPMLimit: 20,
PerGroup: []UserGroupRPMStatus{
{GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
{GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
},
}, status)
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
type rpmUserRepoStub struct {
*userRepoStub
lastUpdated *User
}
func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
if user == nil {
return nil
}
clone := *user
s.lastUpdated = &clone
if s.userRepoStub != nil {
s.userRepoStub.user = &clone
}
return nil
}
func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newRPM := 60
updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
RPMLimit: &newRPM,
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 60, updated.RPMLimit)
require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
}
func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newName := "new"
sameRPM := 10
_, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
Username: &newName,
RPMLimit: &sameRPM,
})
require.NoError(t, err)
require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
}
...@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct { ...@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"` TotalRecharged float64 `json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
RPMLimit int `json:"rpm_limit"`
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
} }
// APIKeyAuthGroupSnapshot 分组快照 // APIKeyAuthGroupSnapshot 分组快照
...@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
RPMLimit int `json:"rpm_limit"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
...@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st ...@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
apiKey.Key = key apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey) snapshot := s.snapshotFromAPIKey(ctx, apiKey)
if snapshot == nil { if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
} }
...@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn ...@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
} }
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil { if apiKey == nil || apiKey.User == nil {
return nil return nil
} }
...@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged, TotalRecharged: apiKey.User.TotalRecharged,
RPMLimit: apiKey.User.RPMLimit,
}, },
} }
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
if err == nil && override != nil {
snapshot.User.UserGroupRPMOverride = override
}
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
}
if apiKey.Group != nil { if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{ snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID, ID: apiKey.Group.ID,
...@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel, DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
RPMLimit: apiKey.Group.RPMLimit,
} }
} }
return snapshot return snapshot
...@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged, TotalRecharged: snapshot.User.TotalRecharged,
RPMLimit: snapshot.User.RPMLimit,
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
}, },
} }
if snapshot.Group != nil { if snapshot.Group != nil {
...@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel, DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
RPMLimit: snapshot.Group.RPMLimit,
} }
} }
s.compileAPIKeyIPRules(apiKey) s.compileAPIKeyIPRules(apiKey)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment