Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
c5781c69
Commit
c5781c69
authored
Jan 01, 2026
by
IanShaw027
Browse files
fix(merge): 解决与 main 分支的配置冲突
- 合并 main 分支的上游错误日志配置 - 保留调度配置 - 合并 beta header 和 failover 配置
parents
e1a9c1ec
34c10204
Changes
53
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/migrations_schema_integration_test.go
View file @
c5781c69
...
@@ -7,7 +7,6 @@ import (
...
@@ -7,7 +7,6 @@ import (
"database/sql"
"database/sql"
"testing"
"testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
...
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
...
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx
:=
testTx
(
t
)
tx
:=
testTx
(
t
)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require
.
NoError
(
t
,
infrastructure
.
ApplyMigrations
(
context
.
Background
(),
integrationDB
))
require
.
NoError
(
t
,
ApplyMigrations
(
context
.
Background
(),
integrationDB
))
// schema_migrations should have at least the current migration set.
// schema_migrations should have at least the current migration set.
var
applied
int
var
applied
int
...
...
backend/internal/
infrastructure
/redis.go
→
backend/internal/
repository
/redis.go
View file @
c5781c69
package
infrastructure
package
repository
import
(
import
(
"time"
"time"
...
...
backend/internal/
infrastructure
/redis_test.go
→
backend/internal/
repository
/redis_test.go
View file @
c5781c69
package
infrastructure
package
repository
import
(
import
(
"testing"
"testing"
...
...
backend/internal/repository/user_subscription_repo.go
View file @
c5781c69
...
@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
...
@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
return
translatePersistenceError
(
err
,
service
.
ErrSubscriptionNotFound
,
nil
)
return
translatePersistenceError
(
err
,
service
.
ErrSubscriptionNotFound
,
nil
)
}
}
// IncrementUsage 原子性地累加用量
并校验限额
。
// IncrementUsage 原子性地累加
订阅
用量。
//
使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
//
限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
//
当更新失败时,会执行额外查询确定具体超出的限额类型
。
//
此处仅负责记录实际消费,确保消费数据的完整性
。
func
(
r
*
userSubscriptionRepository
)
IncrementUsage
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
func
(
r
*
userSubscriptionRepository
)
IncrementUsage
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
const
updateSQL
=
`
// NULL 限额表示无限制
const
atomicUpdateSQL
=
`
UPDATE user_subscriptions us
UPDATE user_subscriptions us
SET
SET
daily_usage_usd = us.daily_usage_usd + $1,
daily_usage_usd = us.daily_usage_usd + $1,
...
@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
...
@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
AND us.deleted_at IS NULL
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND us.group_id = g.id
AND g.deleted_at IS NULL
AND g.deleted_at IS NULL
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
`
`
client
:=
clientFromContext
(
ctx
,
r
.
client
)
client
:=
clientFromContext
(
ctx
,
r
.
client
)
result
,
err
:=
client
.
ExecContext
(
ctx
,
atomicU
pdateSQL
,
costUSD
,
id
)
result
,
err
:=
client
.
ExecContext
(
ctx
,
u
pdateSQL
,
costUSD
,
id
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
}
}
...
@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
...
@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
}
}
if
affected
>
0
{
if
affected
>
0
{
return
nil
// 更新成功
return
nil
}
// affected == 0:可能是订阅不存在、分组已删除、或限额超出
// 执行额外查询确定具体原因
return
r
.
checkIncrementFailureReason
(
ctx
,
id
,
costUSD
)
}
// checkIncrementFailureReason 查询更新失败的具体原因
func
(
r
*
userSubscriptionRepository
)
checkIncrementFailureReason
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
const
checkSQL
=
`
SELECT
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
WHEN g.id IS NULL THEN 'subscription_not_found'
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
ELSE 'unknown'
END AS reason
FROM user_subscriptions us
LEFT JOIN groups g ON us.group_id = g.id
WHERE us.id = $2
`
client
:=
clientFromContext
(
ctx
,
r
.
client
)
rows
,
err
:=
client
.
QueryContext
(
ctx
,
checkSQL
,
costUSD
,
id
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
if
!
rows
.
Next
()
{
return
service
.
ErrSubscriptionNotFound
}
var
reason
string
if
err
:=
rows
.
Scan
(
&
reason
);
err
!=
nil
{
return
err
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
err
}
}
switch
reason
{
// affected == 0:订阅不存在或已删除
case
"subscription_not_found"
,
"subscription_deleted"
,
"group_deleted"
:
return
service
.
ErrSubscriptionNotFound
return
service
.
ErrSubscriptionNotFound
case
"daily_exceeded"
:
return
service
.
ErrDailyLimitExceeded
case
"weekly_exceeded"
:
return
service
.
ErrWeeklyLimitExceeded
case
"monthly_exceeded"
:
return
service
.
ErrMonthlyLimitExceeded
default
:
// unknown 情况理论上不应发生,但作为兜底返回
return
service
.
ErrSubscriptionNotFound
}
}
}
func
(
r
*
userSubscriptionRepository
)
BatchUpdateExpiredStatus
(
ctx
context
.
Context
)
(
int64
,
error
)
{
func
(
r
*
userSubscriptionRepository
)
BatchUpdateExpiredStatus
(
ctx
context
.
Context
)
(
int64
,
error
)
{
...
...
backend/internal/repository/user_subscription_repo_integration_test.go
View file @
c5781c69
...
@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
...
@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s
.
Require
()
.
Equal
(
service
.
SubscriptionStatusExpired
,
updated
.
Status
,
"expected status expired"
)
s
.
Require
()
.
Equal
(
service
.
SubscriptionStatusExpired
,
updated
.
Status
,
"expected status expired"
)
}
}
// --- 限额检查与软删除过滤测试 ---
// --- 软删除过滤测试 ---
func
(
s
*
UserSubscriptionRepoSuite
)
mustCreateGroupWithLimits
(
name
string
,
daily
,
weekly
,
monthly
*
float64
)
*
service
.
Group
{
s
.
T
()
.
Helper
()
create
:=
s
.
client
.
Group
.
Create
()
.
SetName
(
name
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeSubscription
)
if
daily
!=
nil
{
create
.
SetDailyLimitUsd
(
*
daily
)
}
if
weekly
!=
nil
{
create
.
SetWeeklyLimitUsd
(
*
weekly
)
}
if
monthly
!=
nil
{
create
.
SetMonthlyLimitUsd
(
*
monthly
)
}
g
,
err
:=
create
.
Save
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"create group with limits"
)
return
groupEntityToService
(
g
)
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_DailyLimitExceeded
()
{
user
:=
s
.
mustCreateUser
(
"dailylimit@test.com"
,
service
.
RoleUser
)
dailyLimit
:=
10.0
group
:=
s
.
mustCreateGroupWithLimits
(
"g-dailylimit"
,
&
dailyLimit
,
nil
,
nil
)
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
// 先增加 9.0,应该成功
err
:=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
9.0
)
s
.
Require
()
.
NoError
(
err
,
"first increment should succeed"
)
// 再增加 2.0,会超过 10.0 限额,应该失败
err
=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
2.0
)
s
.
Require
()
.
Error
(
err
,
"should fail when daily limit exceeded"
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrDailyLimitExceeded
)
// 验证用量没有变化
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
sub
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
InDelta
(
9.0
,
got
.
DailyUsageUSD
,
1e-6
,
"usage should not change after failed increment"
)
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_WeeklyLimitExceeded
()
{
user
:=
s
.
mustCreateUser
(
"weeklylimit@test.com"
,
service
.
RoleUser
)
weeklyLimit
:=
50.0
group
:=
s
.
mustCreateGroupWithLimits
(
"g-weeklylimit"
,
nil
,
&
weeklyLimit
,
nil
)
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
// 增加 45.0,应该成功
err
:=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
45.0
)
s
.
Require
()
.
NoError
(
err
,
"first increment should succeed"
)
// 再增加 10.0,会超过 50.0 限额,应该失败
err
=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
10.0
)
s
.
Require
()
.
Error
(
err
,
"should fail when weekly limit exceeded"
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrWeeklyLimitExceeded
)
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_MonthlyLimitExceeded
()
{
user
:=
s
.
mustCreateUser
(
"monthlylimit@test.com"
,
service
.
RoleUser
)
monthlyLimit
:=
100.0
group
:=
s
.
mustCreateGroupWithLimits
(
"g-monthlylimit"
,
nil
,
nil
,
&
monthlyLimit
)
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
// 增加 90.0,应该成功
err
:=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
90.0
)
s
.
Require
()
.
NoError
(
err
,
"first increment should succeed"
)
// 再增加 20.0,会超过 100.0 限额,应该失败
err
=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
20.0
)
s
.
Require
()
.
Error
(
err
,
"should fail when monthly limit exceeded"
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrMonthlyLimitExceeded
)
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_NoLimits
()
{
user
:=
s
.
mustCreateUser
(
"nolimits@test.com"
,
service
.
RoleUser
)
group
:=
s
.
mustCreateGroupWithLimits
(
"g-nolimits"
,
nil
,
nil
,
nil
)
// 无限额
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
// 应该可以增加任意金额
err
:=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
1000000.0
)
s
.
Require
()
.
NoError
(
err
,
"should succeed without limits"
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
sub
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
InDelta
(
1000000.0
,
got
.
DailyUsageUSD
,
1e-6
)
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_AtExactLimit
()
{
user
:=
s
.
mustCreateUser
(
"exactlimit@test.com"
,
service
.
RoleUser
)
dailyLimit
:=
10.0
group
:=
s
.
mustCreateGroupWithLimits
(
"g-exactlimit"
,
&
dailyLimit
,
nil
,
nil
)
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
// 正好达到限额应该成功
err
:=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
10.0
)
s
.
Require
()
.
NoError
(
err
,
"should succeed at exact limit"
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
sub
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
InDelta
(
10.0
,
got
.
DailyUsageUSD
,
1e-6
)
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_SoftDeletedGroup
()
{
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_SoftDeletedGroup
()
{
user
:=
s
.
mustCreateUser
(
"softdeleted@test.com"
,
service
.
RoleUser
)
user
:=
s
.
mustCreateUser
(
"softdeleted@test.com"
,
service
.
RoleUser
)
...
@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
...
@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_Concurrent
()
{
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_Concurrent
()
{
user
:=
s
.
mustCreateUser
(
"concurrent@test.com"
,
service
.
RoleUser
)
user
:=
s
.
mustCreateUser
(
"concurrent@test.com"
,
service
.
RoleUser
)
group
:=
s
.
mustCreateGroup
WithLimits
(
"g-concurrent"
,
nil
,
nil
,
nil
)
// 无限额
group
:=
s
.
mustCreateGroup
(
"g-concurrent"
)
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
const
numGoroutines
=
10
const
numGoroutines
=
10
...
@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
...
@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
s
.
Require
()
.
InDelta
(
expectedUsage
,
got
.
MonthlyUsageUSD
,
1e-6
,
"monthly usage should be correctly accumulated"
)
s
.
Require
()
.
InDelta
(
expectedUsage
,
got
.
MonthlyUsageUSD
,
1e-6
,
"monthly usage should be correctly accumulated"
)
}
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestIncrementUsage_ConcurrentWithLimit
()
{
user
:=
s
.
mustCreateUser
(
"concurrentlimit@test.com"
,
service
.
RoleUser
)
dailyLimit
:=
5.0
group
:=
s
.
mustCreateGroupWithLimits
(
"g-concurrentlimit"
,
&
dailyLimit
,
nil
,
nil
)
sub
:=
s
.
mustCreateSubscription
(
user
.
ID
,
group
.
ID
,
nil
)
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
const
numAttempts
=
10
const
incrementPerAttempt
=
1.0
successCount
:=
0
for
i
:=
0
;
i
<
numAttempts
;
i
++
{
err
:=
s
.
repo
.
IncrementUsage
(
s
.
ctx
,
sub
.
ID
,
incrementPerAttempt
)
if
err
==
nil
{
successCount
++
}
}
// 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额)
s
.
Require
()
.
Equal
(
5
,
successCount
,
"exactly 5 increments should succeed (limit=5, increment=1)"
)
// 验证最终用量等于限额
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
sub
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
InDelta
(
dailyLimit
,
got
.
DailyUsageUSD
,
1e-6
,
"daily usage should equal limit"
)
}
func
(
s
*
UserSubscriptionRepoSuite
)
TestTxContext_RollbackIsolation
()
{
func
(
s
*
UserSubscriptionRepoSuite
)
TestTxContext_RollbackIsolation
()
{
baseClient
:=
testEntClient
(
s
.
T
())
baseClient
:=
testEntClient
(
s
.
T
())
tx
,
err
:=
baseClient
.
Tx
(
context
.
Background
())
tx
,
err
:=
baseClient
.
Tx
(
context
.
Background
())
...
...
backend/internal/repository/wire.go
View file @
c5781c69
package
repository
package
repository
import
(
import
(
"database/sql"
"errors"
entsql
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
"github.com/google/wire"
...
@@ -54,4 +59,58 @@ var ProviderSet = wire.NewSet(
...
@@ -54,4 +59,58 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient
,
NewOpenAIOAuthClient
,
NewGeminiOAuthClient
,
NewGeminiOAuthClient
,
NewGeminiCliCodeAssistClient
,
NewGeminiCliCodeAssistClient
,
ProvideEnt
,
ProvideSQLDB
,
ProvideRedis
,
)
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖:config.Config
// 提供:*ent.Client
func
ProvideEnt
(
cfg
*
config
.
Config
)
(
*
ent
.
Client
,
error
)
{
client
,
_
,
err
:=
InitEnt
(
cfg
)
return
client
,
err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func
ProvideSQLDB
(
client
*
ent
.
Client
)
(
*
sql
.
DB
,
error
)
{
if
client
==
nil
{
return
nil
,
errors
.
New
(
"nil ent client"
)
}
// 从 Ent 客户端获取底层驱动
drv
,
ok
:=
client
.
Driver
()
.
(
*
entsql
.
Driver
)
if
!
ok
{
return
nil
,
errors
.
New
(
"ent driver does not expose *sql.DB"
)
}
// 返回驱动持有的 sql.DB 实例
return
drv
.
DB
(),
nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存(如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖:config.Config
// 提供:*redis.Client
func
ProvideRedis
(
cfg
*
config
.
Config
)
*
redis
.
Client
{
return
InitRedis
(
cfg
)
}
backend/internal/server/middleware/recovery.go
View file @
c5781c69
...
@@ -7,7 +7,7 @@ import (
...
@@ -7,7 +7,7 @@ import (
"os"
"os"
"strings"
"strings"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
)
)
...
...
backend/internal/server/middleware/recovery_test.go
View file @
c5781c69
...
@@ -8,7 +8,7 @@ import (
...
@@ -8,7 +8,7 @@ import (
"net/http/httptest"
"net/http/httptest"
"testing"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
...
backend/internal/service/account_service.go
View file @
c5781c69
...
@@ -5,7 +5,7 @@ import (
...
@@ -5,7 +5,7 @@ import (
"fmt"
"fmt"
"time"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
)
...
...
backend/internal/service/admin_service.go
View file @
c5781c69
...
@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType
=
SubscriptionTypeStandard
subscriptionType
=
SubscriptionTypeStandard
}
}
// 限额字段:0 和 nil 都表示"无限制"
dailyLimit
:=
normalizeLimit
(
input
.
DailyLimitUSD
)
weeklyLimit
:=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
monthlyLimit
:=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
group
:=
&
Group
{
group
:=
&
Group
{
Name
:
input
.
Name
,
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Description
:
input
.
Description
,
...
@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive
:
input
.
IsExclusive
,
IsExclusive
:
input
.
IsExclusive
,
Status
:
StatusActive
,
Status
:
StatusActive
,
SubscriptionType
:
subscriptionType
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
input
.
D
ailyLimit
USD
,
DailyLimitUSD
:
d
ailyLimit
,
WeeklyLimitUSD
:
input
.
W
eeklyLimit
USD
,
WeeklyLimitUSD
:
w
eeklyLimit
,
MonthlyLimitUSD
:
input
.
M
onthlyLimit
USD
,
MonthlyLimitUSD
:
m
onthlyLimit
,
}
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return
group
,
nil
return
group
,
nil
}
}
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
func
normalizeLimit
(
limit
*
float64
)
*
float64
{
if
limit
==
nil
||
*
limit
<=
0
{
return
nil
}
return
limit
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
...
@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
SubscriptionType
!=
""
{
if
input
.
SubscriptionType
!=
""
{
group
.
SubscriptionType
=
input
.
SubscriptionType
group
.
SubscriptionType
=
input
.
SubscriptionType
}
}
// 限额字段
支持设置为nil(清除限额)或具体值
// 限额字段
:0 和 nil 都表示"无限制",正数表示具体限额
if
input
.
DailyLimitUSD
!=
nil
{
if
input
.
DailyLimitUSD
!=
nil
{
group
.
DailyLimitUSD
=
input
.
DailyLimitUSD
group
.
DailyLimitUSD
=
normalizeLimit
(
input
.
DailyLimitUSD
)
}
}
if
input
.
WeeklyLimitUSD
!=
nil
{
if
input
.
WeeklyLimitUSD
!=
nil
{
group
.
WeeklyLimitUSD
=
input
.
WeeklyLimitUSD
group
.
WeeklyLimitUSD
=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
}
}
if
input
.
MonthlyLimitUSD
!=
nil
{
if
input
.
MonthlyLimitUSD
!=
nil
{
group
.
MonthlyLimitUSD
=
input
.
MonthlyLimitUSD
group
.
MonthlyLimitUSD
=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
}
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
c5781c69
...
@@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if
bodyJSON
,
err
:=
json
.
Marshal
(
geminiBody
);
err
==
nil
{
truncated
:=
string
(
bodyJSON
)
if
len
(
truncated
)
>
2000
{
truncated
=
truncated
[
:
2000
]
+
"..."
}
log
.
Printf
(
"[Debug] Transformed Gemini request: %s"
,
truncated
)
}
// 构建上游 action
// 构建上游 action
action
:=
"generateContent"
action
:=
"generateContent"
if
claudeReq
.
Stream
{
if
claudeReq
.
Stream
{
...
...
backend/internal/service/antigravity_token_refresher.go
View file @
c5781c69
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"context"
"context"
"fmt"
"time"
"time"
)
)
...
@@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
...
@@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
}
}
// NeedsRefresh 检查账户是否需要刷新
// NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的1
0
分钟刷新窗口,忽略全局配置
// Antigravity 使用固定的1
5
分钟刷新窗口,忽略全局配置
func
(
r
*
AntigravityTokenRefresher
)
NeedsRefresh
(
account
*
Account
,
_
time
.
Duration
)
bool
{
func
(
r
*
AntigravityTokenRefresher
)
NeedsRefresh
(
account
*
Account
,
_
time
.
Duration
)
bool
{
if
!
r
.
CanRefresh
(
account
)
{
if
!
r
.
CanRefresh
(
account
)
{
return
false
return
false
...
@@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati
...
@@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati
if
expiresAt
==
nil
{
if
expiresAt
==
nil
{
return
false
return
false
}
}
return
time
.
Until
(
*
expiresAt
)
<
antigravityRefreshWindow
timeUntilExpiry
:=
time
.
Until
(
*
expiresAt
)
needsRefresh
:=
timeUntilExpiry
<
antigravityRefreshWindow
if
needsRefresh
{
fmt
.
Printf
(
"[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v
\n
"
,
account
.
ID
,
expiresAt
.
Format
(
"2006-01-02 15:04:05"
),
timeUntilExpiry
,
antigravityRefreshWindow
)
}
return
needsRefresh
}
}
// Refresh 执行 token 刷新
// Refresh 执行 token 刷新
...
...
backend/internal/service/api_key_service.go
View file @
c5781c69
...
@@ -8,7 +8,7 @@ import (
...
@@ -8,7 +8,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
)
...
...
backend/internal/service/auth_service.go
View file @
c5781c69
...
@@ -8,7 +8,7 @@ import (
...
@@ -8,7 +8,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/bcrypt"
...
...
backend/internal/service/billing_cache_service.go
View file @
c5781c69
...
@@ -9,7 +9,7 @@ import (
...
@@ -9,7 +9,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
)
)
// 错误定义
// 错误定义
...
...
backend/internal/service/email_service.go
View file @
c5781c69
...
@@ -10,7 +10,7 @@ import (
...
@@ -10,7 +10,7 @@ import (
"strconv"
"strconv"
"time"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
)
)
var
(
var
(
...
...
backend/internal/service/gateway_service.go
View file @
c5781c69
...
@@ -20,6 +20,7 @@ import (
...
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
...
@@ -1061,6 +1062,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -1061,6 +1062,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误)
// 处理错误响应(不可重试的错误)
if
resp
.
StatusCode
>=
400
{
if
resp
.
StatusCode
>=
400
{
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
if
resp
.
StatusCode
==
400
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
FailoverOn400
{
respBody
,
readErr
:=
io
.
ReadAll
(
resp
.
Body
)
if
readErr
!=
nil
{
// ReadAll failed, fall back to normal error handling without consuming the stream
return
s
.
handleErrorResponse
(
ctx
,
resp
,
c
,
account
)
}
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
if
s
.
shouldFailoverOn400
(
respBody
)
{
if
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"Account %d: 400 error, attempting failover: %s"
,
account
.
ID
,
truncateForLog
(
respBody
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
),
)
}
else
{
log
.
Printf
(
"Account %d: 400 error, attempting failover"
,
account
.
ID
)
}
s
.
handleFailoverSideEffects
(
ctx
,
resp
,
account
)
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
}
return
s
.
handleErrorResponse
(
ctx
,
resp
,
c
,
account
)
return
s
.
handleErrorResponse
(
ctx
,
resp
,
c
,
account
)
}
}
...
@@ -1163,6 +1188,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -1163,6 +1188,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理)
// 处理anthropic-beta header(OAuth账号需要特殊处理)
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForApiKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
beta
:=
defaultApiKeyBetaHeader
(
body
);
beta
!=
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
beta
)
}
}
}
}
return
req
,
nil
return
req
,
nil
...
@@ -1215,6 +1247,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
...
@@ -1215,6 +1247,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return
claude
.
DefaultBetaHeader
return
claude
.
DefaultBetaHeader
}
}
func
requestNeedsBetaFeatures
(
body
[]
byte
)
bool
{
tools
:=
gjson
.
GetBytes
(
body
,
"tools"
)
if
tools
.
Exists
()
&&
tools
.
IsArray
()
&&
len
(
tools
.
Array
())
>
0
{
return
true
}
if
strings
.
EqualFold
(
gjson
.
GetBytes
(
body
,
"thinking.type"
)
.
String
(),
"enabled"
)
{
return
true
}
return
false
}
func
defaultApiKeyBetaHeader
(
body
[]
byte
)
string
{
modelID
:=
gjson
.
GetBytes
(
body
,
"model"
)
.
String
()
if
strings
.
Contains
(
strings
.
ToLower
(
modelID
),
"haiku"
)
{
return
claude
.
ApiKeyHaikuBetaHeader
}
return
claude
.
ApiKeyBetaHeader
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
maxBytes
=
2048
}
if
len
(
b
)
>
maxBytes
{
b
=
b
[
:
maxBytes
]
}
s
:=
string
(
b
)
// 保持一行,避免污染日志格式
s
=
strings
.
ReplaceAll
(
s
,
"
\n
"
,
"
\\
n"
)
s
=
strings
.
ReplaceAll
(
s
,
"
\r
"
,
"
\\
r"
)
return
s
}
func
(
s
*
GatewayService
)
shouldFailoverOn400
(
respBody
[]
byte
)
bool
{
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
)))
if
msg
==
""
{
return
false
}
// 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if
strings
.
Contains
(
msg
,
"anthropic-beta"
)
||
strings
.
Contains
(
msg
,
"beta feature"
)
||
strings
.
Contains
(
msg
,
"requires beta"
)
{
return
true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"thought_signature"
)
||
strings
.
Contains
(
msg
,
"signature"
)
{
return
true
}
if
strings
.
Contains
(
msg
,
"tool_use"
)
||
strings
.
Contains
(
msg
,
"tool_result"
)
||
strings
.
Contains
(
msg
,
"tools"
)
{
return
true
}
return
false
}
func
extractUpstreamErrorMessage
(
body
[]
byte
)
string
{
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if
m
:=
gjson
.
GetBytes
(
body
,
"error.message"
)
.
String
();
strings
.
TrimSpace
(
m
)
!=
""
{
inner
:=
strings
.
TrimSpace
(
m
)
// 有些上游会把完整 JSON 作为字符串塞进 message
if
strings
.
HasPrefix
(
inner
,
"{"
)
{
if
innerMsg
:=
gjson
.
Get
(
inner
,
"error.message"
)
.
String
();
strings
.
TrimSpace
(
innerMsg
)
!=
""
{
return
innerMsg
}
}
return
m
}
// 兜底:尝试顶层 message
return
gjson
.
GetBytes
(
body
,
"message"
)
.
String
()
}
func
(
s
*
GatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
)
(
*
ForwardResult
,
error
)
{
func
(
s
*
GatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
)
(
*
ForwardResult
,
error
)
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
...
@@ -1227,6 +1336,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
...
@@ -1227,6 +1336,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch
resp
.
StatusCode
{
switch
resp
.
StatusCode
{
case
400
:
case
400
:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"Upstream 400 error (account=%d platform=%s type=%s): %s"
,
account
.
ID
,
account
.
Platform
,
account
.
Type
,
truncateForLog
(
body
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
),
)
}
c
.
Data
(
http
.
StatusBadRequest
,
"application/json"
,
body
)
c
.
Data
(
http
.
StatusBadRequest
,
"application/json"
,
body
)
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
case
401
:
case
401
:
...
@@ -1706,6 +1825,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -1706,6 +1825,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态(429/529等)
// 标记账号状态(429/529等)
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
// 记录上游错误摘要便于排障(不回显请求内容)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s"
,
resp
.
StatusCode
,
account
.
ID
,
account
.
Platform
,
account
.
Type
,
truncateForLog
(
respBody
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
),
)
}
// 返回简化的错误响应
// 返回简化的错误响应
errMsg
:=
"Upstream request failed"
errMsg
:=
"Upstream request failed"
switch
resp
.
StatusCode
{
switch
resp
.
StatusCode
{
...
@@ -1786,6 +1917,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -1786,6 +1917,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
// OAuth 账号:处理 anthropic-beta header
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForApiKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
beta
:=
defaultApiKeyBetaHeader
(
body
);
beta
!=
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
beta
)
}
}
}
}
return
req
,
nil
return
req
,
nil
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
c5781c69
...
@@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
...
@@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
"properties"
:
map
[
string
]
any
{},
"properties"
:
map
[
string
]
any
{},
}
}
}
}
// 清理 JSON Schema
cleanedParams
:=
cleanToolSchema
(
params
)
funcDecls
=
append
(
funcDecls
,
map
[
string
]
any
{
funcDecls
=
append
(
funcDecls
,
map
[
string
]
any
{
"name"
:
name
,
"name"
:
name
,
"description"
:
desc
,
"description"
:
desc
,
"parameters"
:
p
arams
,
"parameters"
:
cleanedP
arams
,
})
})
}
}
...
@@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
...
@@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
}
}
}
}
// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段
func
cleanToolSchema
(
schema
any
)
any
{
if
schema
==
nil
{
return
nil
}
switch
v
:=
schema
.
(
type
)
{
case
map
[
string
]
any
:
cleaned
:=
make
(
map
[
string
]
any
)
for
key
,
value
:=
range
v
{
// 跳过不支持的字段
if
key
==
"$schema"
||
key
==
"$id"
||
key
==
"$ref"
||
key
==
"additionalProperties"
||
key
==
"minLength"
||
key
==
"maxLength"
||
key
==
"minItems"
||
key
==
"maxItems"
{
continue
}
// 递归清理嵌套对象
cleaned
[
key
]
=
cleanToolSchema
(
value
)
}
// 规范化 type 字段为大写
if
typeVal
,
ok
:=
cleaned
[
"type"
]
.
(
string
);
ok
{
cleaned
[
"type"
]
=
strings
.
ToUpper
(
typeVal
)
}
return
cleaned
case
[]
any
:
cleaned
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
cleaned
[
i
]
=
cleanToolSchema
(
item
)
}
return
cleaned
default
:
return
v
}
}
func
convertClaudeGenerationConfig
(
req
map
[
string
]
any
)
map
[
string
]
any
{
func
convertClaudeGenerationConfig
(
req
map
[
string
]
any
)
map
[
string
]
any
{
out
:=
make
(
map
[
string
]
any
)
out
:=
make
(
map
[
string
]
any
)
if
mt
,
ok
:=
asInt
(
req
[
"max_tokens"
]);
ok
&&
mt
>
0
{
if
mt
,
ok
:=
asInt
(
req
[
"max_tokens"
]);
ok
&&
mt
>
0
{
...
...
backend/internal/service/gemini_messages_compat_service_test.go
0 → 100644
View file @
c5781c69
package
service
import
(
"testing"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func
TestConvertClaudeToolsToGeminiTools_CustomType
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
tools
any
expectedLen
int
description
string
}{
{
name
:
"Standard tools"
,
tools
:
[]
any
{
map
[
string
]
any
{
"name"
:
"get_weather"
,
"description"
:
"Get weather info"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
},
},
},
expectedLen
:
1
,
description
:
"标准工具格式应该正常转换"
,
},
{
name
:
"Custom type tool (MCP format)"
,
tools
:
[]
any
{
map
[
string
]
any
{
"type"
:
"custom"
,
"name"
:
"mcp_tool"
,
"custom"
:
map
[
string
]
any
{
"description"
:
"MCP tool description"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
},
},
},
},
expectedLen
:
1
,
description
:
"Custom类型工具应该从custom字段读取"
,
},
{
name
:
"Mixed standard and custom tools"
,
tools
:
[]
any
{
map
[
string
]
any
{
"name"
:
"standard_tool"
,
"description"
:
"Standard"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
},
},
map
[
string
]
any
{
"type"
:
"custom"
,
"name"
:
"custom_tool"
,
"custom"
:
map
[
string
]
any
{
"description"
:
"Custom"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
},
},
},
},
expectedLen
:
1
,
description
:
"混合工具应该都能正确转换"
,
},
{
name
:
"Custom tool without custom field"
,
tools
:
[]
any
{
map
[
string
]
any
{
"type"
:
"custom"
,
"name"
:
"invalid_custom"
,
// 缺少 custom 字段
},
},
expectedLen
:
0
,
// 应该被跳过
description
:
"缺少custom字段的custom工具应该被跳过"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
convertClaudeToolsToGeminiTools
(
tt
.
tools
)
if
tt
.
expectedLen
==
0
{
if
result
!=
nil
{
t
.
Errorf
(
"%s: expected nil result, got %v"
,
tt
.
description
,
result
)
}
return
}
if
result
==
nil
{
t
.
Fatalf
(
"%s: expected non-nil result"
,
tt
.
description
)
}
if
len
(
result
)
!=
1
{
t
.
Errorf
(
"%s: expected 1 tool declaration, got %d"
,
tt
.
description
,
len
(
result
))
return
}
toolDecl
,
ok
:=
result
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
t
.
Fatalf
(
"%s: result[0] is not map[string]any"
,
tt
.
description
)
}
funcDecls
,
ok
:=
toolDecl
[
"functionDeclarations"
]
.
([]
any
)
if
!
ok
{
t
.
Fatalf
(
"%s: functionDeclarations is not []any"
,
tt
.
description
)
}
toolsArr
,
_
:=
tt
.
tools
.
([]
any
)
expectedFuncCount
:=
0
for
_
,
tool
:=
range
toolsArr
{
toolMap
,
_
:=
tool
.
(
map
[
string
]
any
)
if
toolMap
[
"name"
]
!=
""
{
// 检查是否为有效的custom工具
if
toolMap
[
"type"
]
==
"custom"
{
if
toolMap
[
"custom"
]
!=
nil
{
expectedFuncCount
++
}
}
else
{
expectedFuncCount
++
}
}
}
if
len
(
funcDecls
)
!=
expectedFuncCount
{
t
.
Errorf
(
"%s: expected %d function declarations, got %d"
,
tt
.
description
,
expectedFuncCount
,
len
(
funcDecls
))
}
})
}
}
backend/internal/service/gemini_oauth_service.go
View file @
c5781c69
...
@@ -7,6 +7,7 @@ import (
...
@@ -7,6 +7,7 @@ import (
"fmt"
"fmt"
"io"
"io"
"net/http"
"net/http"
"regexp"
"strconv"
"strconv"
"strings"
"strings"
"time"
"time"
...
@@ -163,6 +164,45 @@ type GeminiTokenInfo struct {
...
@@ -163,6 +164,45 @@ type GeminiTokenInfo struct {
Scope
string
`json:"scope,omitempty"`
Scope
string
`json:"scope,omitempty"`
ProjectID
string
`json:"project_id,omitempty"`
ProjectID
string
`json:"project_id,omitempty"`
OAuthType
string
`json:"oauth_type,omitempty"`
// "code_assist" 或 "ai_studio"
OAuthType
string
`json:"oauth_type,omitempty"`
// "code_assist" 或 "ai_studio"
TierID
string
`json:"tier_id,omitempty"`
// Gemini Code Assist tier: LEGACY/PRO/ULTRA
}
// validateTierID validates tier_id format and length
func
validateTierID
(
tierID
string
)
error
{
if
tierID
==
""
{
return
nil
// Empty is allowed
}
if
len
(
tierID
)
>
64
{
return
fmt
.
Errorf
(
"tier_id exceeds maximum length of 64 characters"
)
}
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
if
!
regexp
.
MustCompile
(
`^[a-zA-Z0-9_/-]+$`
)
.
MatchString
(
tierID
)
{
return
fmt
.
Errorf
(
"tier_id contains invalid characters"
)
}
return
nil
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func
extractTierIDFromAllowedTiers
(
allowedTiers
[]
geminicli
.
AllowedTier
)
string
{
tierID
:=
"LEGACY"
// First pass: look for default tier
for
_
,
tier
:=
range
allowedTiers
{
if
tier
.
IsDefault
&&
strings
.
TrimSpace
(
tier
.
ID
)
!=
""
{
tierID
=
strings
.
TrimSpace
(
tier
.
ID
)
break
}
}
// Second pass: if still LEGACY, take first non-empty tier
if
tierID
==
"LEGACY"
{
for
_
,
tier
:=
range
allowedTiers
{
if
strings
.
TrimSpace
(
tier
.
ID
)
!=
""
{
tierID
=
strings
.
TrimSpace
(
tier
.
ID
)
break
}
}
}
return
tierID
}
}
func
(
s
*
GeminiOAuthService
)
ExchangeCode
(
ctx
context
.
Context
,
input
*
GeminiExchangeCodeInput
)
(
*
GeminiTokenInfo
,
error
)
{
func
(
s
*
GeminiOAuthService
)
ExchangeCode
(
ctx
context
.
Context
,
input
*
GeminiExchangeCodeInput
)
(
*
GeminiTokenInfo
,
error
)
{
...
@@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
...
@@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
expiresAt
:=
time
.
Now
()
.
Unix
()
+
tokenResp
.
ExpiresIn
-
300
expiresAt
:=
time
.
Now
()
.
Unix
()
+
tokenResp
.
ExpiresIn
-
300
projectID
:=
sessionProjectID
projectID
:=
sessionProjectID
var
tierID
string
// 对于 code_assist 模式,project_id 是必需的
// 对于 code_assist 模式,project_id 是必需的
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
if
oauthType
==
"code_assist"
{
if
oauthType
==
"code_assist"
{
if
projectID
==
""
{
if
projectID
==
""
{
var
err
error
var
err
error
projectID
,
err
=
s
.
fetchProjectID
(
ctx
,
tokenResp
.
AccessToken
,
proxyURL
)
projectID
,
tierID
,
err
=
s
.
fetchProjectID
(
ctx
,
tokenResp
.
AccessToken
,
proxyURL
)
if
err
!=
nil
{
if
err
!=
nil
{
// 记录警告但不阻断流程,允许后续补充 project_id
// 记录警告但不阻断流程,允许后续补充 project_id
fmt
.
Printf
(
"[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v
\n
"
,
err
)
fmt
.
Printf
(
"[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v
\n
"
,
err
)
...
@@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
...
@@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt
:
expiresAt
,
ExpiresAt
:
expiresAt
,
Scope
:
tokenResp
.
Scope
,
Scope
:
tokenResp
.
Scope
,
ProjectID
:
projectID
,
ProjectID
:
projectID
,
TierID
:
tierID
,
OAuthType
:
oauthType
,
OAuthType
:
oauthType
,
},
nil
},
nil
}
}
...
@@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
...
@@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
// For Code Assist, project_id is required. Auto-detect if missing.
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if
oauthType
==
"code_assist"
&&
strings
.
TrimSpace
(
tokenInfo
.
ProjectID
)
==
""
{
if
oauthType
==
"code_assist"
&&
strings
.
TrimSpace
(
tokenInfo
.
ProjectID
)
==
""
{
projectID
,
err
:=
s
.
fetchProjectID
(
ctx
,
tokenInfo
.
AccessToken
,
proxyURL
)
projectID
,
tierID
,
err
:=
s
.
fetchProjectID
(
ctx
,
tokenInfo
.
AccessToken
,
proxyURL
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to auto-detect project_id: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to auto-detect project_id: %w"
,
err
)
}
}
...
@@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
...
@@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
return
nil
,
fmt
.
Errorf
(
"failed to auto-detect project_id: empty result"
)
return
nil
,
fmt
.
Errorf
(
"failed to auto-detect project_id: empty result"
)
}
}
tokenInfo
.
ProjectID
=
projectID
tokenInfo
.
ProjectID
=
projectID
tokenInfo
.
TierID
=
tierID
}
}
return
tokenInfo
,
nil
return
tokenInfo
,
nil
...
@@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
...
@@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if
tokenInfo
.
ProjectID
!=
""
{
if
tokenInfo
.
ProjectID
!=
""
{
creds
[
"project_id"
]
=
tokenInfo
.
ProjectID
creds
[
"project_id"
]
=
tokenInfo
.
ProjectID
}
}
if
tokenInfo
.
TierID
!=
""
{
// Validate tier_id before storing
if
err
:=
validateTierID
(
tokenInfo
.
TierID
);
err
==
nil
{
creds
[
"tier_id"
]
=
tokenInfo
.
TierID
}
// Silently skip invalid tier_id (don't block account creation)
}
if
tokenInfo
.
OAuthType
!=
""
{
if
tokenInfo
.
OAuthType
!=
""
{
creds
[
"oauth_type"
]
=
tokenInfo
.
OAuthType
creds
[
"oauth_type"
]
=
tokenInfo
.
OAuthType
}
}
...
@@ -398,35 +448,27 @@ func (s *GeminiOAuthService) Stop() {
...
@@ -398,35 +448,27 @@ func (s *GeminiOAuthService) Stop() {
s
.
sessionStore
.
Stop
()
s
.
sessionStore
.
Stop
()
}
}
func
(
s
*
GeminiOAuthService
)
fetchProjectID
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
string
,
error
)
{
func
(
s
*
GeminiOAuthService
)
fetchProjectID
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
string
,
string
,
error
)
{
if
s
.
codeAssist
==
nil
{
if
s
.
codeAssist
==
nil
{
return
""
,
errors
.
New
(
"code assist client not configured"
)
return
""
,
""
,
errors
.
New
(
"code assist client not configured"
)
}
}
loadResp
,
loadErr
:=
s
.
codeAssist
.
LoadCodeAssist
(
ctx
,
accessToken
,
proxyURL
,
nil
)
loadResp
,
loadErr
:=
s
.
codeAssist
.
LoadCodeAssist
(
ctx
,
accessToken
,
proxyURL
,
nil
)
if
loadErr
==
nil
&&
loadResp
!=
nil
&&
strings
.
TrimSpace
(
loadResp
.
CloudAICompanionProject
)
!=
""
{
return
strings
.
TrimSpace
(
loadResp
.
CloudAICompanionProject
),
nil
}
//
Pick
tier from
allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
//
Extract
tier
ID
from
response (works whether CloudAICompanionProject is set or not)
tierID
:=
"LEGACY"
tierID
:=
"LEGACY"
if
loadResp
!=
nil
{
if
loadResp
!=
nil
{
for
_
,
tier
:=
range
loadResp
.
AllowedTiers
{
tierID
=
extractTierIDFromAllowedTiers
(
loadResp
.
AllowedTiers
)
if
tier
.
IsDefault
&&
strings
.
TrimSpace
(
tier
.
ID
)
!=
""
{
}
tierID
=
strings
.
TrimSpace
(
tier
.
ID
)
break
// If LoadCodeAssist returned a project, use it
}
if
loadErr
==
nil
&&
loadResp
!=
nil
&&
strings
.
TrimSpace
(
loadResp
.
CloudAICompanionProject
)
!=
""
{
}
return
strings
.
TrimSpace
(
loadResp
.
CloudAICompanionProject
),
tierID
,
nil
if
strings
.
TrimSpace
(
tierID
)
==
""
||
tierID
==
"LEGACY"
{
for
_
,
tier
:=
range
loadResp
.
AllowedTiers
{
if
strings
.
TrimSpace
(
tier
.
ID
)
!=
""
{
tierID
=
strings
.
TrimSpace
(
tier
.
ID
)
break
}
}
}
}
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
// (tierID already extracted above, reuse it)
req
:=
&
geminicli
.
OnboardUserRequest
{
req
:=
&
geminicli
.
OnboardUserRequest
{
TierID
:
tierID
,
TierID
:
tierID
,
Metadata
:
geminicli
.
LoadCodeAssistMetadata
{
Metadata
:
geminicli
.
LoadCodeAssistMetadata
{
...
@@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
...
@@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback
,
fbErr
:=
fetchProjectIDFromResourceManager
(
ctx
,
accessToken
,
proxyURL
)
fallback
,
fbErr
:=
fetchProjectIDFromResourceManager
(
ctx
,
accessToken
,
proxyURL
)
if
fbErr
==
nil
&&
strings
.
TrimSpace
(
fallback
)
!=
""
{
if
fbErr
==
nil
&&
strings
.
TrimSpace
(
fallback
)
!=
""
{
return
strings
.
TrimSpace
(
fallback
),
nil
return
strings
.
TrimSpace
(
fallback
),
tierID
,
nil
}
}
return
""
,
err
return
""
,
""
,
err
}
}
if
resp
.
Done
{
if
resp
.
Done
{
if
resp
.
Response
!=
nil
&&
resp
.
Response
.
CloudAICompanionProject
!=
nil
{
if
resp
.
Response
!=
nil
&&
resp
.
Response
.
CloudAICompanionProject
!=
nil
{
switch
v
:=
resp
.
Response
.
CloudAICompanionProject
.
(
type
)
{
switch
v
:=
resp
.
Response
.
CloudAICompanionProject
.
(
type
)
{
case
string
:
case
string
:
return
strings
.
TrimSpace
(
v
),
nil
return
strings
.
TrimSpace
(
v
),
tierID
,
nil
case
map
[
string
]
any
:
case
map
[
string
]
any
:
if
id
,
ok
:=
v
[
"id"
]
.
(
string
);
ok
{
if
id
,
ok
:=
v
[
"id"
]
.
(
string
);
ok
{
return
strings
.
TrimSpace
(
id
),
nil
return
strings
.
TrimSpace
(
id
),
tierID
,
nil
}
}
}
}
}
}
fallback
,
fbErr
:=
fetchProjectIDFromResourceManager
(
ctx
,
accessToken
,
proxyURL
)
fallback
,
fbErr
:=
fetchProjectIDFromResourceManager
(
ctx
,
accessToken
,
proxyURL
)
if
fbErr
==
nil
&&
strings
.
TrimSpace
(
fallback
)
!=
""
{
if
fbErr
==
nil
&&
strings
.
TrimSpace
(
fallback
)
!=
""
{
return
strings
.
TrimSpace
(
fallback
),
nil
return
strings
.
TrimSpace
(
fallback
),
tierID
,
nil
}
}
return
""
,
errors
.
New
(
"onboardUser completed but no project_id returned"
)
return
""
,
""
,
errors
.
New
(
"onboardUser completed but no project_id returned"
)
}
}
time
.
Sleep
(
2
*
time
.
Second
)
time
.
Sleep
(
2
*
time
.
Second
)
}
}
fallback
,
fbErr
:=
fetchProjectIDFromResourceManager
(
ctx
,
accessToken
,
proxyURL
)
fallback
,
fbErr
:=
fetchProjectIDFromResourceManager
(
ctx
,
accessToken
,
proxyURL
)
if
fbErr
==
nil
&&
strings
.
TrimSpace
(
fallback
)
!=
""
{
if
fbErr
==
nil
&&
strings
.
TrimSpace
(
fallback
)
!=
""
{
return
strings
.
TrimSpace
(
fallback
),
nil
return
strings
.
TrimSpace
(
fallback
),
tierID
,
nil
}
}
if
loadErr
!=
nil
{
if
loadErr
!=
nil
{
return
""
,
fmt
.
Errorf
(
"loadCodeAssist failed (%v) and onboardUser timeout after %d attempts"
,
loadErr
,
maxAttempts
)
return
""
,
""
,
fmt
.
Errorf
(
"loadCodeAssist failed (%v) and onboardUser timeout after %d attempts"
,
loadErr
,
maxAttempts
)
}
}
return
""
,
fmt
.
Errorf
(
"onboardUser timeout after %d attempts"
,
maxAttempts
)
return
""
,
""
,
fmt
.
Errorf
(
"onboardUser timeout after %d attempts"
,
maxAttempts
)
}
}
type
googleCloudProject
struct
{
type
googleCloudProject
struct
{
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment