Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
5e060b22
Commit
5e060b22
authored
Apr 23, 2026
by
erio
Browse files
Merge remote-tracking branch 'upstream/main' into feat/channel-insights
# Conflicts: # backend/cmd/server/wire_gen.go
parents
6f04c25e
0a80ec80
Changes
106
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/openai_oauth_service_test.go
View file @
5e060b22
...
...
@@ -8,6 +8,7 @@ import (
"net/url"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
...
...
@@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
require
.
ErrorContains
(
s
.
T
(),
err
,
"request failed"
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint
()
{
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{}))
s
.
srv
.
Close
()
_
,
err
:=
s
.
svc
.
ExchangeCode
(
s
.
ctx
,
"code"
,
"ver"
,
openai
.
DefaultRedirectURI
,
""
,
""
)
require
.
Error
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
"OPENAI_OAUTH_PROXY_REQUIRED"
,
infraerrors
.
Reason
(
err
))
require
.
Contains
(
s
.
T
(),
infraerrors
.
Message
(
err
),
"no proxy is configured"
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestContextCancel
()
{
started
:=
make
(
chan
struct
{})
block
:=
make
(
chan
struct
{})
...
...
backend/internal/repository/usage_billing_repo.go
View file @
5e060b22
...
...
@@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
state
service
.
AccountQuotaState
if
rows
.
Next
()
{
...
...
@@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
&
state
.
DailyUsed
,
&
state
.
DailyLimit
,
&
state
.
WeeklyUsed
,
&
state
.
WeeklyLimit
,
);
err
!=
nil
{
_
=
rows
.
Close
()
return
nil
,
err
}
}
else
{
if
err
:=
rows
.
Err
();
err
!=
nil
{
_
=
rows
.
Close
()
return
nil
,
err
}
_
=
rows
.
Close
()
return
nil
,
service
.
ErrAccountNotFound
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
_
=
rows
.
Close
()
return
nil
,
err
}
if
state
.
TotalLimit
>
0
&&
state
.
TotalUsed
>=
state
.
TotalLimit
&&
(
state
.
TotalUsed
-
amount
)
<
state
.
TotalLimit
{
// 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上
// 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
// "unexpected Parse response" 错误。
if
err
:=
rows
.
Close
();
err
!=
nil
{
return
nil
,
err
}
// 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
// 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
// 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
// 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount),
// 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
crossedTotal
:=
state
.
TotalLimit
>
0
&&
state
.
TotalUsed
>=
state
.
TotalLimit
&&
(
state
.
TotalUsed
-
amount
)
<
state
.
TotalLimit
crossedDaily
:=
state
.
DailyLimit
>
0
&&
state
.
DailyUsed
>=
state
.
DailyLimit
&&
(
state
.
DailyUsed
-
amount
)
<
state
.
DailyLimit
crossedWeekly
:=
state
.
WeeklyLimit
>
0
&&
state
.
WeeklyUsed
>=
state
.
WeeklyLimit
&&
(
state
.
WeeklyUsed
-
amount
)
<
state
.
WeeklyLimit
if
crossedTotal
||
crossedDaily
||
crossedWeekly
{
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
tx
,
service
.
SchedulerOutboxEventAccountChanged
,
&
accountID
,
nil
,
nil
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.usage_billing"
,
"[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v"
,
accountID
,
err
)
return
nil
,
err
...
...
backend/internal/repository/usage_billing_repo_integration_test.go
View file @
5e060b22
...
...
@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
require
.
InDelta
(
t
,
3.5
,
quotaUsed
,
0.000001
)
}
func
TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
client
:=
testEntClient
(
t
)
repo
:=
NewUsageBillingRepository
(
client
,
integrationDB
)
newFixture
:=
func
(
t
*
testing
.
T
,
extra
map
[
string
]
any
)
(
int64
,
int64
)
{
t
.
Helper
()
user
:=
mustCreateUser
(
t
,
client
,
&
service
.
User
{
Email
:
fmt
.
Sprintf
(
"usage-billing-outbox-user-%d-%s@example.com"
,
time
.
Now
()
.
UnixNano
(),
uuid
.
NewString
()),
PasswordHash
:
"hash"
,
})
apiKey
:=
mustCreateApiKey
(
t
,
client
,
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-usage-billing-outbox-"
+
uuid
.
NewString
(),
Name
:
"billing-outbox"
,
})
account
:=
mustCreateAccount
(
t
,
client
,
&
service
.
Account
{
Name
:
"usage-billing-outbox-"
+
uuid
.
NewString
(),
Type
:
service
.
AccountTypeAPIKey
,
Extra
:
extra
,
})
return
apiKey
.
ID
,
account
.
ID
}
outboxCountFor
:=
func
(
t
*
testing
.
T
,
accountID
int64
)
int
{
t
.
Helper
()
var
count
int
require
.
NoError
(
t
,
integrationDB
.
QueryRowContext
(
ctx
,
"SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2"
,
service
.
SchedulerOutboxEventAccountChanged
,
accountID
,
)
.
Scan
(
&
count
))
return
count
}
t
.
Run
(
"daily_first_crossing_enqueues"
,
func
(
t
*
testing
.
T
)
{
apiKeyID
,
accountID
:=
newFixture
(
t
,
map
[
string
]
any
{
"quota_daily_limit"
:
10.0
,
})
// 第一次低于日限额:不应入队 outbox
_
,
err
:=
repo
.
Apply
(
ctx
,
&
service
.
UsageBillingCommand
{
RequestID
:
uuid
.
NewString
(),
APIKeyID
:
apiKeyID
,
AccountID
:
accountID
,
AccountType
:
service
.
AccountTypeAPIKey
,
AccountQuotaCost
:
4
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
0
,
outboxCountFor
(
t
,
accountID
),
"below limit should not enqueue"
)
// 第二次跨越日限额:应入队一次 outbox
_
,
err
=
repo
.
Apply
(
ctx
,
&
service
.
UsageBillingCommand
{
RequestID
:
uuid
.
NewString
(),
APIKeyID
:
apiKeyID
,
AccountID
:
accountID
,
AccountType
:
service
.
AccountTypeAPIKey
,
AccountQuotaCost
:
8
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
outboxCountFor
(
t
,
accountID
),
"crossing daily limit should enqueue once"
)
// 再次递增(已超):不应重复入队
_
,
err
=
repo
.
Apply
(
ctx
,
&
service
.
UsageBillingCommand
{
RequestID
:
uuid
.
NewString
(),
APIKeyID
:
apiKeyID
,
AccountID
:
accountID
,
AccountType
:
service
.
AccountTypeAPIKey
,
AccountQuotaCost
:
2
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
outboxCountFor
(
t
,
accountID
),
"subsequent increments beyond limit should not re-enqueue"
)
})
t
.
Run
(
"weekly_first_crossing_enqueues"
,
func
(
t
*
testing
.
T
)
{
apiKeyID
,
accountID
:=
newFixture
(
t
,
map
[
string
]
any
{
"quota_weekly_limit"
:
10.0
,
})
_
,
err
:=
repo
.
Apply
(
ctx
,
&
service
.
UsageBillingCommand
{
RequestID
:
uuid
.
NewString
(),
APIKeyID
:
apiKeyID
,
AccountID
:
accountID
,
AccountType
:
service
.
AccountTypeAPIKey
,
AccountQuotaCost
:
15
,
// 单次即跨越
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
outboxCountFor
(
t
,
accountID
),
"single-shot crossing weekly limit should enqueue once"
)
})
}
func
TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
newDashboardAggregationRepositoryWithSQL
(
integrationDB
)
...
...
backend/internal/repository/user_group_rate_repo.go
View file @
5e060b22
...
...
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql
sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
// NewUserGroupRateRepository 创建用户专属分组倍率
/RPM
仓储
func
NewUserGroupRateRepository
(
sqlDB
*
sql
.
DB
)
service
.
UserGroupRateRepository
{
return
&
userGroupRateRepository
{
sql
:
sqlDB
}
}
// GetByUserID 获取用户
的
所有专属分组
倍率
// GetByUserID 获取用户所有专属分组
rate_multiplier(仅返回非 NULL 的条目)
func
(
r
*
userGroupRateRepository
)
GetByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
map
[
int64
]
float64
,
error
)
{
query
:=
`SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
query
:=
`SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1
AND rate_multiplier IS NOT NULL
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
userID
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return
result
,
nil
}
// GetByUserIDs 批量获取多个用户的专属分组倍率。
// 返回结构:map[userID]map[groupID]rate
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
func
(
r
*
userGroupRateRepository
)
GetByUserIDs
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
map
[
int64
]
float64
,
error
)
{
result
:=
make
(
map
[
int64
]
map
[
int64
]
float64
,
len
(
userIDs
))
if
len
(
userIDs
)
==
0
{
...
...
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
WHERE user_id = ANY($1)
WHERE user_id = ANY($1)
AND rate_multiplier IS NOT NULL
`
,
pq
.
Array
(
uniqueIDs
))
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return
result
,
nil
}
// GetByGroupID 获取指定分组下所有用户的专属
倍率
// GetByGroupID 获取指定分组下所有用户的专属
配置(rate 与 rpm_override 任一非 NULL 即返回)
func
(
r
*
userGroupRateRepository
)
GetByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
UserGroupRateEntry
,
error
)
{
query
:=
`
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
, ugr.rpm_override
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
...
...
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var
result
[]
service
.
UserGroupRateEntry
for
rows
.
Next
()
{
var
entry
service
.
UserGroupRateEntry
if
err
:=
rows
.
Scan
(
&
entry
.
UserID
,
&
entry
.
UserName
,
&
entry
.
UserEmail
,
&
entry
.
UserNotes
,
&
entry
.
UserStatus
,
&
entry
.
RateMultiplier
);
err
!=
nil
{
var
rate
sql
.
NullFloat64
var
rpm
sql
.
NullInt32
if
err
:=
rows
.
Scan
(
&
entry
.
UserID
,
&
entry
.
UserName
,
&
entry
.
UserEmail
,
&
entry
.
UserNotes
,
&
entry
.
UserStatus
,
&
rate
,
&
rpm
);
err
!=
nil
{
return
nil
,
err
}
if
rate
.
Valid
{
v
:=
rate
.
Float64
entry
.
RateMultiplier
=
&
v
}
if
rpm
.
Valid
{
v
:=
int
(
rpm
.
Int32
)
entry
.
RPMOverride
=
&
v
}
result
=
append
(
result
,
entry
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
...
...
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return
result
,
nil
}
// GetByUserAndGroup 获取用户在特定分组的专属
倍率
// GetByUserAndGroup 获取用户在特定分组的专属
rate_multiplier(NULL 返回 nil)
func
(
r
*
userGroupRateRepository
)
GetByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
query
:=
`SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var
rate
f
loat64
var
rate
sql
.
NullF
loat64
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
userID
,
groupID
},
&
rate
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
...
...
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if
err
!=
nil
{
return
nil
,
err
}
return
&
rate
,
nil
if
!
rate
.
Valid
{
return
nil
,
nil
}
v
:=
rate
.
Float64
return
&
v
,
nil
}
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
func
(
r
*
userGroupRateRepository
)
GetRPMOverrideByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
int
,
error
)
{
query
:=
`SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var
rpm
sql
.
NullInt32
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
userID
,
groupID
},
&
rpm
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
}
if
err
!=
nil
{
return
nil
,
err
}
if
!
rpm
.
Valid
{
return
nil
,
nil
}
v
:=
int
(
rpm
.
Int32
)
return
&
v
,
nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
func
(
r
*
userGroupRateRepository
)
SyncUserGroupRates
(
ctx
context
.
Context
,
userID
int64
,
rates
map
[
int64
]
*
float64
)
error
{
if
len
(
rates
)
==
0
{
// 如果传入空 map,删除该用户的所有专属倍率
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1`
,
userID
)
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1
`
,
userID
);
err
!=
nil
{
return
err
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`
,
userID
)
return
err
}
// 分离需要删除和需要 upsert 的记录
var
toDelete
[]
int64
var
clearGroupIDs
[]
int64
upsertGroupIDs
:=
make
([]
int64
,
0
,
len
(
rates
))
upsertRates
:=
make
([]
float64
,
0
,
len
(
rates
))
for
groupID
,
rate
:=
range
rates
{
if
rate
==
nil
{
toDelete
=
append
(
toDelete
,
groupID
)
clearGroupIDs
=
append
(
clearGroupIDs
,
groupID
)
}
else
{
upsertGroupIDs
=
append
(
upsertGroupIDs
,
groupID
)
upsertRates
=
append
(
upsertRates
,
*
rate
)
}
}
// 删除指定的记录
if
len
(
toDelete
)
>
0
{
if
len
(
clearGroupIDs
)
>
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1 AND group_id = ANY($2)
`
,
userID
,
pq
.
Array
(
clearGroupIDs
));
err
!=
nil
{
return
err
}
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`
,
userID
,
pq
.
Array
(
toDelete
));
err
!=
nil
{
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)
AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
userID
,
pq
.
Array
(
clearGroupIDs
));
err
!=
nil
{
return
err
}
}
// Upsert 记录
now
:=
time
.
Now
()
if
len
(
upsertGroupIDs
)
>
0
{
now
:=
time
.
Now
()
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT
...
...
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return
nil
}
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
// 语义:
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
// - 出现的用户行:upsert rate_multiplier。
func
(
r
*
userGroupRateRepository
)
SyncGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
service
.
GroupRateMultiplierInput
)
error
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`
,
groupID
);
err
!=
nil
{
keepUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
for
_
,
e
:=
range
entries
{
keepUserIDs
=
append
(
keepUserIDs
,
e
.
UserID
)
}
// 未在 entries 列表中的行:清空 rate_multiplier。
if
len
(
keepUserIDs
)
==
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
}
else
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`
,
groupID
,
pq
.
Array
(
keepUserIDs
));
err
!=
nil
{
return
err
}
}
// 清空后若整行 NULL 则删除。
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
);
err
!=
nil
{
return
err
}
if
len
(
entries
)
==
0
{
return
nil
}
userIDs
:=
make
([]
int64
,
len
(
entries
))
rates
:=
make
([]
float64
,
len
(
entries
))
for
i
,
e
:=
range
entries
{
...
...
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return
err
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
// 语义:
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
func
(
r
*
userGroupRateRepository
)
SyncGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
service
.
GroupRPMOverrideInput
)
error
{
keepUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
var
clearUserIDs
[]
int64
upsertUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
upsertValues
:=
make
([]
int32
,
0
,
len
(
entries
))
for
_
,
e
:=
range
entries
{
keepUserIDs
=
append
(
keepUserIDs
,
e
.
UserID
)
if
e
.
RPMOverride
==
nil
{
clearUserIDs
=
append
(
clearUserIDs
,
e
.
UserID
)
}
else
{
upsertUserIDs
=
append
(
upsertUserIDs
,
e
.
UserID
)
upsertValues
=
append
(
upsertValues
,
int32
(
*
e
.
RPMOverride
))
}
}
// 未在 entries 列表中的行:清空 rpm_override。
if
len
(
keepUserIDs
)
==
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
}
else
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`
,
groupID
,
pq
.
Array
(
keepUserIDs
));
err
!=
nil
{
return
err
}
}
// 显式 clear 的行。
if
len
(
clearUserIDs
)
>
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id = ANY($2)
`
,
groupID
,
pq
.
Array
(
clearUserIDs
));
err
!=
nil
{
return
err
}
}
// 清空后若整行 NULL 则删除。
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
);
err
!=
nil
{
return
err
}
if
len
(
upsertUserIDs
)
>
0
{
now
:=
time
.
Now
()
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
`
,
groupID
,
now
,
pq
.
Array
(
upsertUserIDs
),
pq
.
Array
(
upsertValues
))
if
err
!=
nil
{
return
err
}
}
return
nil
}
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
func
(
r
*
userGroupRateRepository
)
ClearGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
)
error
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
)
return
err
}
// DeleteByGroupID 删除指定分组的所有用户专属条目
func
(
r
*
userGroupRateRepository
)
DeleteByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`
,
groupID
)
return
err
}
// DeleteByUserID 删除指定用户的所有专属
倍率
// DeleteByUserID 删除指定用户的所有专属
条目
func
(
r
*
userGroupRateRepository
)
DeleteByUserID
(
ctx
context
.
Context
,
userID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1`
,
userID
)
return
err
...
...
backend/internal/repository/user_repo.go
View file @
5e060b22
...
...
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource
(
userSignupSourceOrDefault
(
userIn
.
SignupSource
))
.
SetNillableLastLoginAt
(
userIn
.
LastLoginAt
)
.
SetNillableLastActiveAt
(
userIn
.
LastActiveAt
)
.
SetRpmLimit
(
userIn
.
RPMLimit
)
.
Save
(
txCtx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
...
...
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType
(
userIn
.
BalanceNotifyThresholdType
)
.
SetNillableBalanceNotifyThreshold
(
userIn
.
BalanceNotifyThreshold
)
.
SetBalanceNotifyExtraEmails
(
marshalExtraEmails
(
userIn
.
BalanceNotifyExtraEmails
))
.
SetTotalRecharged
(
userIn
.
TotalRecharged
)
SetTotalRecharged
(
userIn
.
TotalRecharged
)
.
SetRpmLimit
(
userIn
.
RPMLimit
)
if
userIn
.
SignupSource
!=
""
{
updateOp
=
updateOp
.
SetSignupSource
(
userIn
.
SignupSource
)
}
...
...
backend/internal/repository/user_rpm_cache.go
0 → 100644
View file @
5e060b22
package
repository
import
(
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 用户/分组级 RPM 计数器 Redis 实现。
//
// 设计说明:
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
const
(
userGroupRPMKeyPrefix
=
"rpm:ug:"
userRPMKeyPrefix
=
"rpm:u:"
userRPMKeyTTL
=
120
*
time
.
Second
)
type
userRPMCacheImpl
struct
{
rdb
*
redis
.
Client
}
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
func
NewUserRPMCache
(
rdb
*
redis
.
Client
)
service
.
UserRPMCache
{
return
&
userRPMCacheImpl
{
rdb
:
rdb
}
}
// minuteTS 获取当前 Redis 服务端分钟时间戳。
func
(
c
*
userRPMCacheImpl
)
minuteTS
(
ctx
context
.
Context
)
(
int64
,
error
)
{
t
,
err
:=
c
.
rdb
.
Time
(
ctx
)
.
Result
()
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"redis TIME: %w"
,
err
)
}
return
t
.
Unix
()
/
60
,
nil
}
// atomicIncr 原子 INCR+EXPIRE。
func
(
c
*
userRPMCacheImpl
)
atomicIncr
(
ctx
context
.
Context
,
key
string
)
(
int
,
error
)
{
pipe
:=
c
.
rdb
.
TxPipeline
()
incr
:=
pipe
.
Incr
(
ctx
,
key
)
pipe
.
Expire
(
ctx
,
key
,
userRPMKeyTTL
)
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user rpm increment: %w"
,
err
)
}
return
int
(
incr
.
Val
()),
nil
}
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
func
(
c
*
userRPMCacheImpl
)
IncrementUserGroupRPM
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d:%d"
,
userGroupRPMKeyPrefix
,
userID
,
groupID
,
minute
)
return
c
.
atomicIncr
(
ctx
,
key
)
}
// IncrementUserRPM 递增用户分钟计数。
func
(
c
*
userRPMCacheImpl
)
IncrementUserRPM
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
userRPMKeyPrefix
,
userID
,
minute
)
return
c
.
atomicIncr
(
ctx
,
key
)
}
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
func
(
c
*
userRPMCacheImpl
)
GetUserGroupRPM
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d:%d"
,
userGroupRPMKeyPrefix
,
userID
,
groupID
,
minute
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user group rpm get: %w"
,
err
)
}
return
val
,
nil
}
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
func
(
c
*
userRPMCacheImpl
)
GetUserRPM
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
userRPMKeyPrefix
,
userID
,
minute
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user rpm get: %w"
,
err
)
}
return
val
,
nil
}
backend/internal/repository/wire.go
View file @
5e060b22
...
...
@@ -98,10 +98,12 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache
,
NewTempUnschedCache
,
NewTimeoutCounterCache
,
NewOpenAI403CounterCache
,
NewInternal500CounterCache
,
ProvideConcurrencyCache
,
ProvideSessionLimitCache
,
NewRPMCache
,
NewUserRPMCache
,
NewUserMsgQueueCache
,
NewDashboardCache
,
NewEmailCache
,
...
...
backend/internal/server/api_contract_test.go
View file @
5e060b22
...
...
@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) {
"role": "user",
"balance": 12.5,
"concurrency": 5,
"rpm_limit": 0,
"status": "active",
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
...
...
@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_group_id_on_invalid_request": null,
"require_oauth_only": false,
"require_privacy_set": false,
"rpm_limit": 0,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...
...
@@ -892,6 +895,7 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...
...
@@ -1090,7 +1094,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
redeemService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
...
...
backend/internal/server/routes/admin.go
View file @
5e060b22
...
...
@@ -224,6 +224,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users
.
GET
(
"/:id/usage"
,
h
.
Admin
.
User
.
GetUserUsage
)
users
.
GET
(
"/:id/balance-history"
,
h
.
Admin
.
User
.
GetBalanceHistory
)
users
.
POST
(
"/:id/replace-group"
,
h
.
Admin
.
User
.
ReplaceGroup
)
users
.
GET
(
"/:id/rpm-status"
,
h
.
Admin
.
User
.
GetUserRPMStatus
)
// User attribute values
users
.
GET
(
"/:id/attributes"
,
h
.
Admin
.
UserAttribute
.
GetUserAttributes
)
...
...
@@ -247,6 +248,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups
.
GET
(
"/:id/rate-multipliers"
,
h
.
Admin
.
Group
.
GetGroupRateMultipliers
)
groups
.
PUT
(
"/:id/rate-multipliers"
,
h
.
Admin
.
Group
.
BatchSetGroupRateMultipliers
)
groups
.
DELETE
(
"/:id/rate-multipliers"
,
h
.
Admin
.
Group
.
ClearGroupRateMultipliers
)
groups
.
PUT
(
"/:id/rpm-overrides"
,
h
.
Admin
.
Group
.
BatchSetGroupRPMOverrides
)
groups
.
DELETE
(
"/:id/rpm-overrides"
,
h
.
Admin
.
Group
.
ClearGroupRPMOverrides
)
groups
.
GET
(
"/:id/api-keys"
,
h
.
Admin
.
Group
.
GetGroupAPIKeys
)
}
}
...
...
backend/internal/service/account.go
View file @
5e060b22
...
...
@@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit
return
false
}
switch
capability
{
case
OpenAIImagesCapabilityBasic
:
case
OpenAIImagesCapabilityBasic
,
OpenAIImagesCapabilityNative
:
return
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeAPIKey
case
OpenAIImagesCapabilityNative
:
return
a
.
Type
==
AccountTypeAPIKey
default
:
return
true
}
...
...
backend/internal/service/account_test_service.go
View file @
5e060b22
...
...
@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
...
...
@@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
return
nil
}
// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via C
hatGPT backend
API.
// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via C
odex /responses
API.
func
(
s
*
AccountTestService
)
testOpenAIImageOAuth
(
c
*
gin
.
Context
,
ctx
context
.
Context
,
account
*
Account
,
modelID
,
prompt
string
)
error
{
authToken
:=
account
.
GetOpenAIAccessToken
()
if
authToken
==
""
{
...
...
@@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
c
.
Writer
.
Flush
()
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_start"
,
Model
:
modelID
})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"
Initializing ChatGPT backend
...
\n
"
})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"
Calling Codex /responses image tool
...
\n
"
})
// Build headers (replicating buildOpenAIBackendAPIHeaders logic)
headers
:=
buildOpenAIBackendAPIHeadersForTest
(
ctx
,
account
,
authToken
,
s
.
accountRepo
)
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
parsed
:=
&
OpenAIImagesRequest
{
Endpoint
:
openAIImagesGenerationsEndpoint
,
Model
:
strings
.
TrimSpace
(
modelID
),
Prompt
:
prompt
,
}
applyOpenAIImagesDefaults
(
parsed
)
client
,
err
:=
new
OpenAI
BackendAPIClient
(
proxyURL
)
responsesBody
,
err
:=
build
OpenAI
ImagesResponsesRequest
(
parsed
,
parsed
.
Model
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Failed to
create clien
t: %s"
,
err
.
Error
()))
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Failed to
build image reques
t: %s"
,
err
.
Error
()))
}
// Bootstrap
if
bootstrapErr
:=
bootstrapOpenAIBackendAPI
(
ctx
,
client
,
headers
);
bootstrapErr
!=
nil
{
log
.
Printf
(
"OpenAI image test bootstrap warning: %v"
,
bootstrapErr
)
}
// Fetch chat requirements
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"Fetching chat requirements...
\n
"
})
chatReqs
,
err
:=
fetchOpenAIChatRequirements
(
ctx
,
client
,
headers
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
chatgptCodexAPIURL
,
bytes
.
NewReader
(
responsesBody
))
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Ch
at requ
irements failed: %s"
,
err
.
Error
())
)
return
s
.
sendErrorAndEnd
(
c
,
"Failed to cre
at
e
requ
est"
)
}
if
chatReqs
.
Arkose
.
Required
{
return
s
.
sendErrorAndEnd
(
c
,
"Unsupported challenge: arkose required"
)
req
.
Host
=
"chatgpt.com"
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
authToken
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Accept"
,
"text/event-stream"
)
req
.
Header
.
Set
(
"OpenAI-Beta"
,
"responses=experimental"
)
req
.
Header
.
Set
(
"originator"
,
"opencode"
)
if
customUA
:=
strings
.
TrimSpace
(
account
.
GetOpenAIUserAgent
());
customUA
!=
""
{
req
.
Header
.
Set
(
"User-Agent"
,
customUA
)
}
else
{
req
.
Header
.
Set
(
"User-Agent"
,
codexCLIUserAgent
)
}
// Initialize and prepare conversation
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"Preparing image conversation...
\n
"
})
parentMessageID
:=
uuid
.
NewString
()
proofToken
:=
generateOpenAIProofToken
(
chatReqs
.
ProofOfWork
.
Required
,
chatReqs
.
ProofOfWork
.
Seed
,
chatReqs
.
ProofOfWork
.
Difficulty
,
headers
.
Get
(
"User-Agent"
))
_
=
initializeOpenAIImageConversation
(
ctx
,
client
,
headers
)
conduitToken
,
err
:=
prepareOpenAIImageConversation
(
ctx
,
client
,
headers
,
prompt
,
parentMessageID
,
chatReqs
.
Token
,
proofToken
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Conversation prepare failed: %s"
,
err
.
Error
()))
if
chatgptAccountID
:=
strings
.
TrimSpace
(
account
.
GetChatGPTAccountID
());
chatgptAccountID
!=
""
{
req
.
Header
.
Set
(
"chatgpt-account-id"
,
chatgptAccountID
)
}
// Build simplified conversation request (no file uploads)
convReq
:=
buildOpenAIImageTestConversationRequest
(
prompt
,
parentMessageID
)
convHeaders
:=
cloneHTTPHeader
(
headers
)
convHeaders
.
Set
(
"Accept"
,
"text/event-stream"
)
convHeaders
.
Set
(
"Content-Type"
,
"application/json"
)
convHeaders
.
Set
(
"openai-sentinel-chat-requirements-token"
,
chatReqs
.
Token
)
if
conduitToken
!=
""
{
convHeaders
.
Set
(
"x-conduit-token"
,
conduitToken
)
}
if
proofToken
!=
""
{
convHeaders
.
Set
(
"openai-sentinel-proof-token"
,
proofToken
)
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"Generating image...
\n
"
})
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
DisableAutoReadResponse
()
.
SetHeaders
(
headerToMap
(
convHeaders
))
.
SetBodyJsonMarshal
(
convReq
)
.
Post
(
openAIChatGPTConversationURL
)
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"
Conversation
request failed: %s"
,
err
.
Error
()))
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"
Responses API
request failed: %s"
,
err
.
Error
()))
}
defer
func
()
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
...
...
@@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
}
}()
if
resp
.
StatusCode
>=
400
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Conversation API returned %d"
,
resp
.
StatusCode
))
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
message
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
if
message
==
""
{
message
=
fmt
.
Sprintf
(
"Responses API returned %d"
,
resp
.
StatusCode
)
}
return
s
.
sendErrorAndEnd
(
c
,
message
)
}
startTime
:=
time
.
Now
()
conversationID
,
pointerInfos
,
_
,
_
,
err
:=
readOpenAIImageConversationStream
(
resp
,
startTime
)
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"
Stream read failed
: %s"
,
err
.
Error
()))
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"
Failed to read image response
: %s"
,
err
.
Error
()))
}
pointerInfos
=
mergeOpenAIImagePointerInfos
(
pointerInfos
,
nil
)
if
conversationID
!=
""
&&
!
hasOpenAIFileServicePointerInfos
(
pointerInfos
)
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"Waiting for image generation to complete...
\n
"
})
polledPointers
,
pollErr
:=
pollOpenAIImageConversation
(
ctx
,
client
,
headers
,
conversationID
)
if
pollErr
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Poll failed: %s"
,
pollErr
.
Error
()))
}
pointerInfos
=
mergeOpenAIImagePointerInfos
(
pointerInfos
,
polledPointers
)
results
,
_
,
_
,
_
,
_
,
err
:=
collectOpenAIImagesFromResponsesBody
(
body
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Failed to parse image response: %s"
,
err
.
Error
()))
}
pointerInfos
=
preferOpenAIFileServicePointerInfos
(
pointerInfos
)
if
len
(
pointerInfos
)
==
0
{
return
s
.
sendErrorAndEnd
(
c
,
"No images returned from conversation"
)
if
len
(
results
)
==
0
{
return
s
.
sendErrorAndEnd
(
c
,
"No images returned from responses API"
)
}
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"Downloading generated image...
\n
"
})
// Download and encode each image
for
_
,
pointer
:=
range
pointerInfos
{
downloadURL
,
err
:=
fetchOpenAIImageDownloadURL
(
ctx
,
client
,
headers
,
conversationID
,
pointer
.
Pointer
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Download URL fetch failed: %s"
,
err
.
Error
()))
}
data
,
err
:=
downloadOpenAIImageBytes
(
ctx
,
client
,
headers
,
downloadURL
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Image download failed: %s"
,
err
.
Error
()))
}
b64
:=
base64
.
StdEncoding
.
EncodeToString
(
data
)
mimeType
:=
http
.
DetectContentType
(
data
)
if
pointer
.
Prompt
!=
""
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
pointer
.
Prompt
})
for
_
,
item
:=
range
results
{
if
item
.
RevisedPrompt
!=
""
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
item
.
RevisedPrompt
})
}
mimeType
:=
openAIImageOutputMIMEType
(
item
.
OutputFormat
)
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"image"
,
ImageURL
:
"data:"
+
mimeType
+
";base64,"
+
b64
,
ImageURL
:
"data:"
+
mimeType
+
";base64,"
+
item
.
Result
,
MimeType
:
mimeType
,
})
}
...
...
@@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
return
nil
}
// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes.
// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without
// requiring the full gateway service dependency.
func
buildOpenAIBackendAPIHeadersForTest
(
ctx
context
.
Context
,
account
*
Account
,
token
string
,
repo
AccountRepository
)
http
.
Header
{
// Ensure device and session IDs exist
deviceID
:=
account
.
GetOpenAIDeviceID
()
sessionID
:=
account
.
GetOpenAISessionID
()
if
deviceID
==
""
||
sessionID
==
""
{
updates
:=
map
[
string
]
any
{}
if
deviceID
==
""
{
deviceID
=
uuid
.
NewString
()
updates
[
"openai_device_id"
]
=
deviceID
}
if
sessionID
==
""
{
sessionID
=
uuid
.
NewString
()
updates
[
"openai_session_id"
]
=
sessionID
}
if
account
.
Extra
==
nil
{
account
.
Extra
=
map
[
string
]
any
{}
}
for
key
,
value
:=
range
updates
{
account
.
Extra
[
key
]
=
value
}
if
repo
!=
nil
{
updateCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
5
*
time
.
Second
)
defer
cancel
()
_
=
repo
.
UpdateExtra
(
updateCtx
,
account
.
ID
,
updates
)
}
}
headers
:=
make
(
http
.
Header
)
headers
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
headers
.
Set
(
"Accept"
,
"application/json"
)
headers
.
Set
(
"Origin"
,
"https://chatgpt.com"
)
headers
.
Set
(
"Referer"
,
"https://chatgpt.com/"
)
headers
.
Set
(
"Sec-Fetch-Dest"
,
"empty"
)
headers
.
Set
(
"Sec-Fetch-Mode"
,
"cors"
)
headers
.
Set
(
"Sec-Fetch-Site"
,
"same-origin"
)
headers
.
Set
(
"User-Agent"
,
openAIImageBackendUserAgent
)
if
customUA
:=
strings
.
TrimSpace
(
account
.
GetOpenAIUserAgent
());
customUA
!=
""
{
headers
.
Set
(
"User-Agent"
,
customUA
)
}
if
chatgptAccountID
:=
strings
.
TrimSpace
(
account
.
GetChatGPTAccountID
());
chatgptAccountID
!=
""
{
headers
.
Set
(
"chatgpt-account-id"
,
chatgptAccountID
)
}
if
deviceID
!=
""
{
headers
.
Set
(
"oai-device-id"
,
deviceID
)
headers
.
Set
(
"Cookie"
,
"oai-did="
+
deviceID
)
}
if
sessionID
!=
""
{
headers
.
Set
(
"oai-session-id"
,
sessionID
)
}
return
headers
}
// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request.
func
buildOpenAIImageTestConversationRequest
(
prompt
,
parentMessageID
string
)
map
[
string
]
any
{
promptText
:=
strings
.
TrimSpace
(
prompt
)
if
promptText
==
""
{
promptText
=
"Generate an image."
}
metadata
:=
map
[
string
]
any
{
"developer_mode_connector_ids"
:
[]
any
{},
"selected_github_repos"
:
[]
any
{},
"selected_all_github_repos"
:
false
,
"system_hints"
:
[]
string
{
"picture_v2"
},
"serialization_metadata"
:
map
[
string
]
any
{
"custom_symbol_offsets"
:
[]
any
{},
},
}
message
:=
map
[
string
]
any
{
"id"
:
uuid
.
NewString
(),
"author"
:
map
[
string
]
any
{
"role"
:
"user"
},
"content"
:
map
[
string
]
any
{
"content_type"
:
"text"
,
"parts"
:
[]
any
{
promptText
},
},
"metadata"
:
metadata
,
"create_time"
:
float64
(
time
.
Now
()
.
UnixMilli
())
/
1000
,
}
return
map
[
string
]
any
{
"action"
:
"next"
,
"client_prepare_state"
:
"sent"
,
"parent_message_id"
:
parentMessageID
,
"messages"
:
[]
any
{
message
},
"model"
:
"auto"
,
"timezone_offset_min"
:
openAITimezoneOffsetMinutes
(),
"timezone"
:
openAITimezoneName
(),
"conversation_mode"
:
map
[
string
]
any
{
"kind"
:
"primary_assistant"
},
"system_hints"
:
[]
string
{
"picture_v2"
},
"supports_buffering"
:
true
,
"supported_encodings"
:
[]
string
{
"v1"
},
"client_contextual_info"
:
map
[
string
]
any
{
"app_name"
:
"chatgpt.com"
},
"force_nulligen"
:
false
,
"force_paragen"
:
false
,
"force_paragen_model_slug"
:
""
,
"force_rate_limit"
:
false
,
"websocket_request_id"
:
uuid
.
NewString
(),
}
}
func
(
s
*
AccountTestService
)
sendEvent
(
c
*
gin
.
Context
,
event
TestEvent
)
{
eventJSON
,
_
:=
json
.
Marshal
(
event
)
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
eventJSON
);
err
!=
nil
{
...
...
backend/internal/service/account_test_service_openai_image_test.go
0 → 100644
View file @
5e060b22
package
service
import
(
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/1/test"
,
nil
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
},
},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
"data: {
\"
type
\"
:
\"
response.output_item.done
\"
,
\"
item
\"
:{
\"
id
\"
:
\"
ig_123
\"
,
\"
type
\"
:
\"
image_generation_call
\"
,
\"
result
\"
:
\"
aGVsbG8=
\"
,
\"
revised_prompt
\"
:
\"
draw a cat
\"
,
\"
output_format
\"
:
\"
png
\"
}}
\n\n
"
+
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
created_at
\"
:1710000006,
\"
tool_usage
\"
:{
\"
image_gen
\"
:{
\"
images
\"
:1}},
\"
output
\"
:[]}}
\n\n
"
+
"data: [DONE]
\n\n
"
,
)),
},
}
svc
:=
&
AccountTestService
{
httpUpstream
:
upstream
}
account
:=
&
Account
{
ID
:
53
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token-123"
,
},
}
err
:=
svc
.
testOpenAIImageOAuth
(
c
,
context
.
Background
(),
account
,
"gpt-image-2"
,
"draw a cat"
)
require
.
NoError
(
t
,
err
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"Calling Codex /responses image tool"
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"data:image/png;base64,aGVsbG8="
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"
\"
success
\"
:true"
)
}
backend/internal/service/admin_service.go
View file @
5e060b22
...
...
@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"sort"
"strings"
"time"
...
...
@@ -32,6 +33,7 @@ type AdminService interface {
UpdateUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
,
operation
string
,
notes
string
)
(
*
User
,
error
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
sortBy
,
sortOrder
string
)
([]
APIKey
,
int64
,
error
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
GetUserRPMStatus
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserRPMStatus
,
error
)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
...
...
@@ -50,6 +52,8 @@ type AdminService interface {
GetGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
)
([]
UserGroupRateEntry
,
error
)
ClearGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
)
error
BatchSetGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
GroupRateMultiplierInput
)
error
ClearGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
)
error
BatchSetGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
GroupRPMOverrideInput
)
error
UpdateGroupSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
// API Key management (admin)
...
...
@@ -114,6 +118,7 @@ type CreateUserInput struct {
Notes
string
Balance
float64
Concurrency
int
RPMLimit
int
AllowedGroups
[]
int64
}
...
...
@@ -124,6 +129,7 @@ type UpdateUserInput struct {
Notes
*
string
Balance
*
float64
// 使用指针区分"未提供"和"设置为0"
Concurrency
*
int
// 使用指针区分"未提供"和"设置为0"
RPMLimit
*
int
// 使用指针区分"未提供"和"设置为0"
Status
string
AllowedGroups
*
[]
int64
// 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
...
...
@@ -199,6 +205,8 @@ type CreateGroupInput struct {
RequireOAuthOnly
bool
RequirePrivacySet
bool
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制)
RPMLimit
int
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
RequireOAuthOnly
*
bool
RequirePrivacySet
*
bool
MessagesDispatchModelConfig
*
OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
RPMLimit
*
int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys
int64
// 迁移的 Key 数量
}
// UserRPMStatus describes a user's current per-minute RPM usage.
type
UserRPMStatus
struct
{
UserRPMUsed
int
`json:"user_rpm_used"`
UserRPMLimit
int
`json:"user_rpm_limit"`
PerGroup
[]
UserGroupRPMStatus
`json:"per_group"`
}
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
type
UserGroupRPMStatus
struct
{
GroupID
int64
`json:"group_id"`
GroupName
string
`json:"group_name"`
Used
int
`json:"used"`
Limit
int
`json:"limit"`
Source
string
`json:"source"`
// "group" | "override"
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type
BulkUpdateAccountsResult
struct
{
Success
int
`json:"success"`
...
...
@@ -463,6 +489,8 @@ const (
proxyQualityClientUserAgent
=
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
var
ErrRPMStatusUnavailable
=
infraerrors
.
New
(
http
.
StatusNotImplemented
,
"RPM_STATUS_UNAVAILABLE"
,
"RPM cache not available"
)
// adminServiceImpl implements AdminService
type
adminServiceImpl
struct
{
userRepo
UserRepository
...
...
@@ -472,6 +500,7 @@ type adminServiceImpl struct {
apiKeyRepo
APIKeyRepository
redeemCodeRepo
RedeemCodeRepository
userGroupRateRepo
UserGroupRateRepository
userRPMCache
UserRPMCache
billingCacheService
*
BillingCacheService
proxyProber
ProxyExitInfoProber
proxyLatencyCache
ProxyLatencyCache
...
...
@@ -496,6 +525,7 @@ func NewAdminService(
apiKeyRepo
APIKeyRepository
,
redeemCodeRepo
RedeemCodeRepository
,
userGroupRateRepo
UserGroupRateRepository
,
userRPMCache
UserRPMCache
,
billingCacheService
*
BillingCacheService
,
proxyProber
ProxyExitInfoProber
,
proxyLatencyCache
ProxyLatencyCache
,
...
...
@@ -514,6 +544,7 @@ func NewAdminService(
apiKeyRepo
:
apiKeyRepo
,
redeemCodeRepo
:
redeemCodeRepo
,
userGroupRateRepo
:
userGroupRateRepo
,
userRPMCache
:
userRPMCache
,
billingCacheService
:
billingCacheService
,
proxyProber
:
proxyProber
,
proxyLatencyCache
:
proxyLatencyCache
,
...
...
@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Role
:
RoleUser
,
// Always create as regular user, never admin
Balance
:
input
.
Balance
,
Concurrency
:
input
.
Concurrency
,
RPMLimit
:
input
.
RPMLimit
,
Status
:
StatusActive
,
AllowedGroups
:
input
.
AllowedGroups
,
}
...
...
@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency
:=
user
.
Concurrency
oldStatus
:=
user
.
Status
oldRole
:=
user
.
Role
oldRPMLimit
:=
user
.
RPMLimit
if
input
.
Email
!=
""
{
user
.
Email
=
input
.
Email
...
...
@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user
.
Concurrency
=
*
input
.
Concurrency
}
if
input
.
RPMLimit
!=
nil
{
user
.
RPMLimit
=
*
input
.
RPMLimit
}
if
input
.
AllowedGroups
!=
nil
{
user
.
AllowedGroups
=
*
input
.
AllowedGroups
}
...
...
@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
if
s
.
authCacheInvalidator
!=
nil
{
if
user
.
Concurrency
!=
oldConcurrency
||
user
.
Status
!=
oldStatus
||
user
.
Role
!=
oldRole
{
// RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
if
user
.
Concurrency
!=
oldConcurrency
||
user
.
Status
!=
oldStatus
||
user
.
Role
!=
oldRole
||
user
.
RPMLimit
!=
oldRPMLimit
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByUserID
(
ctx
,
user
.
ID
)
}
}
...
...
@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return
keys
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetUserRPMStatus
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserRPMStatus
,
error
)
{
if
s
.
userRPMCache
==
nil
{
return
nil
,
ErrRPMStatusUnavailable
}
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
}
userRPMUsed
,
err
:=
s
.
userRPMCache
.
GetUserRPM
(
ctx
,
userID
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get user rpm: user_id=%d err=%v"
,
userID
,
err
)
}
keys
,
_
,
err
:=
s
.
GetUserAPIKeys
(
ctx
,
userID
,
1
,
1000
,
""
,
""
)
if
err
!=
nil
{
return
nil
,
err
}
groupIDSet
:=
make
(
map
[
int64
]
struct
{})
for
_
,
key
:=
range
keys
{
if
key
.
GroupID
!=
nil
&&
*
key
.
GroupID
>
0
{
groupIDSet
[
*
key
.
GroupID
]
=
struct
{}{}
}
}
groupIDs
:=
make
([]
int64
,
0
,
len
(
groupIDSet
))
for
groupID
:=
range
groupIDSet
{
groupIDs
=
append
(
groupIDs
,
groupID
)
}
sort
.
Slice
(
groupIDs
,
func
(
i
,
j
int
)
bool
{
return
groupIDs
[
i
]
<
groupIDs
[
j
]
})
var
perGroup
[]
UserGroupRPMStatus
for
_
,
groupID
:=
range
groupIDs
{
used
,
getErr
:=
s
.
userRPMCache
.
GetUserGroupRPM
(
ctx
,
userID
,
groupID
)
if
getErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get user group rpm: user_id=%d group_id=%d err=%v"
,
userID
,
groupID
,
getErr
)
}
entry
:=
UserGroupRPMStatus
{
GroupID
:
groupID
,
Used
:
used
,
}
if
s
.
groupRepo
!=
nil
{
if
group
,
groupErr
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
groupID
);
groupErr
==
nil
&&
group
!=
nil
{
entry
.
GroupName
=
group
.
Name
entry
.
Limit
=
group
.
RPMLimit
entry
.
Source
=
"group"
}
else
if
groupErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get group rpm status metadata: group_id=%d err=%v"
,
groupID
,
groupErr
)
}
}
if
s
.
userGroupRateRepo
!=
nil
{
override
,
overrideErr
:=
s
.
userGroupRateRepo
.
GetRPMOverrideByUserAndGroup
(
ctx
,
userID
,
groupID
)
if
overrideErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get rpm override: user_id=%d group_id=%d err=%v"
,
userID
,
groupID
,
overrideErr
)
}
else
if
override
!=
nil
{
entry
.
Limit
=
*
override
entry
.
Source
=
"override"
}
}
perGroup
=
append
(
perGroup
,
entry
)
}
return
&
UserRPMStatus
{
UserRPMUsed
:
userRPMUsed
,
UserRPMLimit
:
user
.
RPMLimit
,
PerGroup
:
perGroup
,
},
nil
}
func
(
s
*
adminServiceImpl
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
{
// Return mock data for now
return
map
[
string
]
any
{
...
...
@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequirePrivacySet
:
input
.
RequirePrivacySet
,
DefaultMappedModel
:
input
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
normalizeOpenAIMessagesDispatchModelConfig
(
input
.
MessagesDispatchModelConfig
),
RPMLimit
:
input
.
RPMLimit
,
}
sanitizeGroupMessagesDispatchFields
(
group
)
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
...
...
@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
MessagesDispatchModelConfig
!=
nil
{
group
.
MessagesDispatchModelConfig
=
normalizeOpenAIMessagesDispatchModelConfig
(
*
input
.
MessagesDispatchModelConfig
)
}
if
input
.
RPMLimit
!=
nil
{
group
.
RPMLimit
=
*
input
.
RPMLimit
}
sanitizeGroupMessagesDispatchFields
(
group
)
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
...
...
@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
return
group
,
nil
}
...
...
@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
return
s
.
userGroupRateRepo
.
SyncGroupRateMultipliers
(
ctx
,
groupID
,
entries
)
}
func
(
s
*
adminServiceImpl
)
ClearGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
)
error
{
if
s
.
userGroupRateRepo
==
nil
{
return
nil
}
if
err
:=
s
.
userGroupRateRepo
.
ClearGroupRPMOverrides
(
ctx
,
groupID
);
err
!=
nil
{
return
err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
groupID
)
}
return
nil
}
func
(
s
*
adminServiceImpl
)
BatchSetGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
GroupRPMOverrideInput
)
error
{
if
s
.
userGroupRateRepo
==
nil
{
return
nil
}
for
_
,
e
:=
range
entries
{
if
e
.
RPMOverride
!=
nil
&&
*
e
.
RPMOverride
<
0
{
return
infraerrors
.
BadRequest
(
"INVALID_RPM_OVERRIDE"
,
fmt
.
Sprintf
(
"rpm_override must be >= 0 (user_id=%d)"
,
e
.
UserID
))
}
}
if
err
:=
s
.
userGroupRateRepo
.
SyncGroupRPMOverrides
(
ctx
,
groupID
,
entries
);
err
!=
nil
{
return
err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
groupID
)
}
return
nil
}
func
(
s
*
adminServiceImpl
)
UpdateGroupSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
{
return
s
.
groupRepo
.
UpdateSortOrders
(
ctx
,
updates
)
}
...
...
backend/internal/service/admin_service_group_rate_test.go
View file @
5e060b22
...
...
@@ -5,8 +5,10 @@ package service
import
(
"context"
"errors"
"net/http"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
...
...
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID
int64
syncedEntries
[]
GroupRateMultiplierInput
syncGroupErr
error
rpmSyncedGroupID
int64
rpmSyncedEntries
[]
GroupRPMOverrideInput
rpmSyncErr
error
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
GetByUserID
(
_
context
.
Context
,
_
int64
)
(
map
[
int64
]
float64
,
error
)
{
...
...
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic
(
"unexpected GetByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
GetRPMOverrideByUserAndGroup
(
_
context
.
Context
,
_
,
_
int64
)
(
*
int
,
error
)
{
panic
(
"unexpected GetRPMOverrideByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
GetByGroupID
(
_
context
.
Context
,
groupID
int64
)
([]
UserGroupRateEntry
,
error
)
{
if
s
.
getByGroupIDErr
!=
nil
{
return
nil
,
s
.
getByGroupIDErr
...
...
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return
s
.
syncGroupErr
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
SyncGroupRPMOverrides
(
_
context
.
Context
,
groupID
int64
,
entries
[]
GroupRPMOverrideInput
)
error
{
s
.
rpmSyncedGroupID
=
groupID
s
.
rpmSyncedEntries
=
entries
return
s
.
rpmSyncErr
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
ClearGroupRPMOverrides
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected ClearGroupRPMOverrides call"
)
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
DeleteByGroupID
(
_
context
.
Context
,
groupID
int64
)
error
{
s
.
deletedGroupIDs
=
append
(
s
.
deletedGroupIDs
,
groupID
)
return
s
.
deleteByGroupErr
...
...
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo
:=
&
userGroupRateRepoStubForGroupRate
{
getByGroupIDData
:
map
[
int64
][]
UserGroupRateEntry
{
10
:
{
{
UserID
:
1
,
UserName
:
"alice"
,
UserEmail
:
"alice@test.com"
,
RateMultiplier
:
1.5
},
{
UserID
:
2
,
UserName
:
"bob"
,
UserEmail
:
"bob@test.com"
,
RateMultiplier
:
0.8
},
{
UserID
:
1
,
UserName
:
"alice"
,
UserEmail
:
"alice@test.com"
,
RateMultiplier
:
ptrFloat
(
1.5
)
},
{
UserID
:
2
,
UserName
:
"bob"
,
UserEmail
:
"bob@test.com"
,
RateMultiplier
:
ptrFloat
(
0.8
)
},
},
},
}
...
...
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require
.
Len
(
t
,
entries
,
2
)
require
.
Equal
(
t
,
int64
(
1
),
entries
[
0
]
.
UserID
)
require
.
Equal
(
t
,
"alice"
,
entries
[
0
]
.
UserName
)
require
.
Equal
(
t
,
1.5
,
entries
[
0
]
.
RateMultiplier
)
require
.
NotNil
(
t
,
entries
[
0
]
.
RateMultiplier
)
require
.
Equal
(
t
,
1.5
,
*
entries
[
0
]
.
RateMultiplier
)
require
.
Equal
(
t
,
int64
(
2
),
entries
[
1
]
.
UserID
)
require
.
Equal
(
t
,
0.8
,
entries
[
1
]
.
RateMultiplier
)
require
.
NotNil
(
t
,
entries
[
1
]
.
RateMultiplier
)
require
.
Equal
(
t
,
0.8
,
*
entries
[
1
]
.
RateMultiplier
)
})
t
.
Run
(
"returns nil when repo is nil"
,
func
(
t
*
testing
.
T
)
{
...
...
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require
.
Contains
(
t
,
err
.
Error
(),
"sync failed"
)
})
}
func
TestAdminService_BatchSetGroupRPMOverrides
(
t
*
testing
.
T
)
{
t
.
Run
(
"syncs entries to repo"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
userGroupRateRepoStubForGroupRate
{}
svc
:=
&
adminServiceImpl
{
userGroupRateRepo
:
repo
}
override
:=
20
entries
:=
[]
GroupRPMOverrideInput
{{
UserID
:
2
,
RPMOverride
:
&
override
}}
err
:=
svc
.
BatchSetGroupRPMOverrides
(
context
.
Background
(),
10
,
entries
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
),
repo
.
rpmSyncedGroupID
)
require
.
Equal
(
t
,
entries
,
repo
.
rpmSyncedEntries
)
})
t
.
Run
(
"rejects negative override as bad request"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
userGroupRateRepoStubForGroupRate
{}
svc
:=
&
adminServiceImpl
{
userGroupRateRepo
:
repo
}
negative
:=
-
1
err
:=
svc
.
BatchSetGroupRPMOverrides
(
context
.
Background
(),
10
,
[]
GroupRPMOverrideInput
{
{
UserID
:
2
,
RPMOverride
:
&
negative
},
})
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
infraerrors
.
Code
(
err
))
require
.
Zero
(
t
,
repo
.
rpmSyncedGroupID
)
})
}
backend/internal/service/admin_service_group_test.go
View file @
5e060b22
...
...
@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
}
func
TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange
(
t
*
testing
.
T
)
{
existingGroup
:=
&
Group
{
ID
:
1
,
Name
:
"existing-group"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
RPMLimit
:
10
,
}
repo
:=
&
groupRepoStubForAdmin
{
getByID
:
existingGroup
}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
,
authCacheInvalidator
:
invalidator
,
}
rpmLimit
:=
60
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
1
,
&
UpdateGroupInput
{
RPMLimit
:
&
rpmLimit
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
Equal
(
t
,
60
,
repo
.
updated
.
RPMLimit
)
require
.
Equal
(
t
,
[]
int64
{
1
},
invalidator
.
groupIDs
,
"分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存"
)
}
func
TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
...
...
backend/internal/service/admin_service_list_users_test.go
View file @
5e060b22
...
...
@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic
(
"unexpected GetByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
GetRPMOverrideByUserAndGroup
(
_
context
.
Context
,
_
,
_
int64
)
(
*
int
,
error
)
{
panic
(
"unexpected GetRPMOverrideByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
SyncUserGroupRates
(
_
context
.
Context
,
userID
int64
,
rates
map
[
int64
]
*
float64
)
error
{
panic
(
"unexpected SyncUserGroupRates call"
)
}
...
...
@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic
(
"unexpected SyncGroupRateMultipliers call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
SyncGroupRPMOverrides
(
_
context
.
Context
,
_
int64
,
_
[]
GroupRPMOverrideInput
)
error
{
panic
(
"unexpected SyncGroupRPMOverrides call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
ClearGroupRPMOverrides
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected ClearGroupRPMOverrides call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
DeleteByGroupID
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected DeleteByGroupID call"
)
}
...
...
backend/internal/service/admin_service_rpm_status_test.go
0 → 100644
View file @
5e060b22
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type
rpmStatusUserRepoStub
struct
{
UserRepository
user
*
User
}
func
(
s
*
rpmStatusUserRepoStub
)
GetByID
(
_
context
.
Context
,
_
int64
)
(
*
User
,
error
)
{
return
s
.
user
,
nil
}
type
rpmStatusAPIKeyRepoStub
struct
{
APIKeyRepository
keys
[]
APIKey
}
func
(
s
*
rpmStatusAPIKeyRepoStub
)
ListByUserID
(
_
context
.
Context
,
_
int64
,
_
pagination
.
PaginationParams
,
_
APIKeyListFilters
)
([]
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
s
.
keys
,
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
keys
))},
nil
}
type
rpmStatusGroupRepoStub
struct
{
GroupRepository
groups
map
[
int64
]
*
Group
}
func
(
s
*
rpmStatusGroupRepoStub
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
groups
[
id
],
nil
}
type
rpmStatusRateRepoStub
struct
{
UserGroupRateRepository
overrides
map
[
int64
]
*
int
}
func
(
s
*
rpmStatusRateRepoStub
)
GetRPMOverrideByUserAndGroup
(
_
context
.
Context
,
_
,
groupID
int64
)
(
*
int
,
error
)
{
return
s
.
overrides
[
groupID
],
nil
}
type
rpmStatusCacheStub
struct
{
UserRPMCache
userUsed
int
groupUsed
map
[
int64
]
int
}
func
(
s
*
rpmStatusCacheStub
)
IncrementUserGroupRPM
(
context
.
Context
,
int64
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
rpmStatusCacheStub
)
IncrementUserRPM
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
rpmStatusCacheStub
)
GetUserGroupRPM
(
_
context
.
Context
,
_
,
groupID
int64
)
(
int
,
error
)
{
return
s
.
groupUsed
[
groupID
],
nil
}
func
(
s
*
rpmStatusCacheStub
)
GetUserRPM
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
s
.
userUsed
,
nil
}
func
TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits
(
t
*
testing
.
T
)
{
groupOneID
:=
int64
(
1
)
groupTwoID
:=
int64
(
2
)
override
:=
7
svc
:=
&
adminServiceImpl
{
userRepo
:
&
rpmStatusUserRepoStub
{
user
:
&
User
{
ID
:
42
,
RPMLimit
:
20
,
}},
apiKeyRepo
:
&
rpmStatusAPIKeyRepoStub
{
keys
:
[]
APIKey
{
{
ID
:
100
,
UserID
:
42
,
GroupID
:
&
groupTwoID
},
{
ID
:
101
,
UserID
:
42
,
GroupID
:
&
groupOneID
},
{
ID
:
102
,
UserID
:
42
,
GroupID
:
&
groupTwoID
},
{
ID
:
103
,
UserID
:
42
},
}},
groupRepo
:
&
rpmStatusGroupRepoStub
{
groups
:
map
[
int64
]
*
Group
{
groupOneID
:
{
ID
:
groupOneID
,
Name
:
"group-one"
,
RPMLimit
:
10
},
groupTwoID
:
{
ID
:
groupTwoID
,
Name
:
"group-two"
,
RPMLimit
:
60
},
}},
userGroupRateRepo
:
&
rpmStatusRateRepoStub
{
overrides
:
map
[
int64
]
*
int
{
groupTwoID
:
&
override
,
}},
userRPMCache
:
&
rpmStatusCacheStub
{
userUsed
:
5
,
groupUsed
:
map
[
int64
]
int
{
groupOneID
:
3
,
groupTwoID
:
4
,
},
},
}
status
,
err
:=
svc
.
GetUserRPMStatus
(
context
.
Background
(),
42
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
&
UserRPMStatus
{
UserRPMUsed
:
5
,
UserRPMLimit
:
20
,
PerGroup
:
[]
UserGroupRPMStatus
{
{
GroupID
:
groupOneID
,
GroupName
:
"group-one"
,
Used
:
3
,
Limit
:
10
,
Source
:
"group"
},
{
GroupID
:
groupTwoID
,
GroupName
:
"group-two"
,
Used
:
4
,
Limit
:
7
,
Source
:
"override"
},
},
},
status
)
}
backend/internal/service/admin_service_update_user_rpm_test.go
0 → 100644
View file @
5e060b22
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
type
rpmUserRepoStub
struct
{
*
userRepoStub
lastUpdated
*
User
}
func
(
s
*
rpmUserRepoStub
)
Update
(
_
context
.
Context
,
user
*
User
)
error
{
if
user
==
nil
{
return
nil
}
clone
:=
*
user
s
.
lastUpdated
=
&
clone
if
s
.
userRepoStub
!=
nil
{
s
.
userRepoStub
.
user
=
&
clone
}
return
nil
}
func
TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange
(
t
*
testing
.
T
)
{
base
:=
&
userRepoStub
{
user
:
&
User
{
ID
:
42
,
Email
:
"u@example.com"
,
RPMLimit
:
10
}}
repo
:=
&
rpmUserRepoStub
{
userRepoStub
:
base
}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
userRepo
:
repo
,
redeemCodeRepo
:
&
redeemRepoStub
{},
authCacheInvalidator
:
invalidator
,
}
newRPM
:=
60
updated
,
err
:=
svc
.
UpdateUser
(
context
.
Background
(),
42
,
&
UpdateUserInput
{
RPMLimit
:
&
newRPM
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
updated
)
require
.
Equal
(
t
,
60
,
updated
.
RPMLimit
)
require
.
Equal
(
t
,
[]
int64
{
42
},
invalidator
.
userIDs
,
"仅修改 RPMLimit 也应失效 API Key 认证缓存"
)
}
func
TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged
(
t
*
testing
.
T
)
{
base
:=
&
userRepoStub
{
user
:
&
User
{
ID
:
42
,
Email
:
"u@example.com"
,
RPMLimit
:
10
,
Username
:
"old"
}}
repo
:=
&
rpmUserRepoStub
{
userRepoStub
:
base
}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
userRepo
:
repo
,
redeemCodeRepo
:
&
redeemRepoStub
{},
authCacheInvalidator
:
invalidator
,
}
newName
:=
"new"
sameRPM
:=
10
_
,
err
:=
svc
.
UpdateUser
(
context
.
Background
(),
42
,
&
UpdateUserInput
{
Username
:
&
newName
,
RPMLimit
:
&
sameRPM
,
})
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
invalidator
.
userIDs
,
"只改 username 不应触发认证缓存失效"
)
}
backend/internal/service/api_key_auth_cache.go
View file @
5e060b22
...
...
@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThreshold
*
float64
`json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
`json:"balance_notify_extra_emails,omitempty"`
TotalRecharged
float64
`json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
RPMLimit
int
`json:"rpm_limit"`
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
UserGroupRPMOverride
*
int
`json:"user_group_rpm_override,omitempty"`
}
// APIKeyAuthGroupSnapshot 分组快照
...
...
@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
AllowMessagesDispatch
bool
`json:"allow_messages_dispatch"`
DefaultMappedModel
string
`json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config,omitempty"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
RPMLimit
int
`json:"rpm_limit"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
5e060b22
...
...
@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const
apiKeyAuthSnapshotVersion
=
5
// v
5
: added
TotalRecharged for percentage thre
sho
ld
const
apiKeyAuthSnapshotVersion
=
7
// v
7
: added
UserGroupRPMOverride on user snap
sho
t
type
apiKeyAuthCacheConfig
struct
{
l1Size
int
...
...
@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
apiKey
.
Key
=
key
snapshot
:=
s
.
snapshotFromAPIKey
(
apiKey
)
snapshot
:=
s
.
snapshotFromAPIKey
(
ctx
,
apiKey
)
if
snapshot
==
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
ErrAPIKeyNotFound
)
}
...
...
@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
return
s
.
snapshotToAPIKey
(
key
,
entry
.
Snapshot
),
true
,
nil
}
func
(
s
*
APIKeyService
)
snapshotFromAPIKey
(
apiKey
*
APIKey
)
*
APIKeyAuthSnapshot
{
func
(
s
*
APIKeyService
)
snapshotFromAPIKey
(
ctx
context
.
Context
,
apiKey
*
APIKey
)
*
APIKeyAuthSnapshot
{
if
apiKey
==
nil
||
apiKey
.
User
==
nil
{
return
nil
}
...
...
@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThreshold
:
apiKey
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
apiKey
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
apiKey
.
User
.
TotalRecharged
,
RPMLimit
:
apiKey
.
User
.
RPMLimit
,
},
}
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
if
apiKey
.
GroupID
!=
nil
&&
*
apiKey
.
GroupID
>
0
&&
s
.
userGroupRateRepo
!=
nil
{
override
,
err
:=
s
.
userGroupRateRepo
.
GetRPMOverrideByUserAndGroup
(
ctx
,
apiKey
.
UserID
,
*
apiKey
.
GroupID
)
if
err
==
nil
&&
override
!=
nil
{
snapshot
.
User
.
UserGroupRPMOverride
=
override
}
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
}
if
apiKey
.
Group
!=
nil
{
snapshot
.
Group
=
&
APIKeyAuthGroupSnapshot
{
ID
:
apiKey
.
Group
.
ID
,
...
...
@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
AllowMessagesDispatch
:
apiKey
.
Group
.
AllowMessagesDispatch
,
DefaultMappedModel
:
apiKey
.
Group
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
apiKey
.
Group
.
MessagesDispatchModelConfig
,
RPMLimit
:
apiKey
.
Group
.
RPMLimit
,
}
}
return
snapshot
...
...
@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThreshold
:
snapshot
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
snapshot
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
snapshot
.
User
.
TotalRecharged
,
RPMLimit
:
snapshot
.
User
.
RPMLimit
,
UserGroupRPMOverride
:
snapshot
.
User
.
UserGroupRPMOverride
,
},
}
if
snapshot
.
Group
!=
nil
{
...
...
@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
AllowMessagesDispatch
:
snapshot
.
Group
.
AllowMessagesDispatch
,
DefaultMappedModel
:
snapshot
.
Group
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
snapshot
.
Group
.
MessagesDispatchModelConfig
,
RPMLimit
:
snapshot
.
Group
.
RPMLimit
,
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment