Commit c5781c69 authored by IanShaw027's avatar IanShaw027
Browse files

fix(merge): 解决与 main 分支的配置冲突

- 合并 main 分支的上游错误日志配置
- 保留调度配置
- 合并 beta header 和 failover 配置
parents e1a9c1ec 34c10204
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"database/sql" "database/sql"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { ...@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t) tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows). // Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB)) require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set. // schema_migrations should have at least the current migration set.
var applied int var applied int
......
package infrastructure package repository
import ( import (
"time" "time"
......
package infrastructure package repository
import ( import (
"testing" "testing"
......
...@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i ...@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
// IncrementUsage 原子性地累加用量并校验限额 // IncrementUsage 原子性地累加订阅用量。
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。 // 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
// 当更新失败时,会执行额外查询确定具体超出的限额类型 // 此处仅负责记录实际消费,确保消费数据的完整性
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加 const updateSQL = `
// NULL 限额表示无限制
const atomicUpdateSQL = `
UPDATE user_subscriptions us UPDATE user_subscriptions us
SET SET
daily_usage_usd = us.daily_usage_usd + $1, daily_usage_usd = us.daily_usage_usd + $1,
...@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 ...@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
AND us.deleted_at IS NULL AND us.deleted_at IS NULL
AND us.group_id = g.id AND us.group_id = g.id
AND g.deleted_at IS NULL AND g.deleted_at IS NULL
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
` `
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id) result, err := client.ExecContext(ctx, updateSQL, costUSD, id)
if err != nil { if err != nil {
return err return err
} }
...@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 ...@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
} }
if affected > 0 { if affected > 0 {
return nil // 更新成功 return nil
}
// affected == 0:可能是订阅不存在、分组已删除、或限额超出
// 执行额外查询确定具体原因
return r.checkIncrementFailureReason(ctx, id, costUSD)
}
// checkIncrementFailureReason 查询更新失败的具体原因
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
const checkSQL = `
SELECT
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
WHEN g.id IS NULL THEN 'subscription_not_found'
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
ELSE 'unknown'
END AS reason
FROM user_subscriptions us
LEFT JOIN groups g ON us.group_id = g.id
WHERE us.id = $2
`
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return service.ErrSubscriptionNotFound
}
var reason string
if err := rows.Scan(&reason); err != nil {
return err
}
if err := rows.Err(); err != nil {
return err
} }
switch reason { // affected == 0:订阅不存在或已删除
case "subscription_not_found", "subscription_deleted", "group_deleted":
return service.ErrSubscriptionNotFound return service.ErrSubscriptionNotFound
case "daily_exceeded":
return service.ErrDailyLimitExceeded
case "weekly_exceeded":
return service.ErrWeeklyLimitExceeded
case "monthly_exceeded":
return service.ErrMonthlyLimitExceeded
default:
// unknown 情况理论上不应发生,但作为兜底返回
return service.ErrSubscriptionNotFound
}
} }
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
......
...@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba ...@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
} }
// --- 限额检查与软删除过滤测试 --- // --- 软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group {
s.T().Helper()
create := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeSubscription)
if daily != nil {
create.SetDailyLimitUsd(*daily)
}
if weekly != nil {
create.SetWeeklyLimitUsd(*weekly)
}
if monthly != nil {
create.SetMonthlyLimitUsd(*monthly)
}
g, err := create.Save(s.ctx)
s.Require().NoError(err, "create group with limits")
return groupEntityToService(g)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() {
user := s.mustCreateUser("dailylimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 先增加 9.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 2.0,会超过 10.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0)
s.Require().Error(err, "should fail when daily limit exceeded")
s.Require().ErrorIs(err, service.ErrDailyLimitExceeded)
// 验证用量没有变化
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() {
user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser)
weeklyLimit := 50.0
group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 45.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 10.0,会超过 50.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().Error(err, "should fail when weekly limit exceeded")
s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() {
user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser)
monthlyLimit := 100.0
group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 90.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 20.0,会超过 100.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0)
s.Require().Error(err, "should fail when monthly limit exceeded")
s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() {
user := s.mustCreateUser("nolimits@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 应该可以增加任意金额
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0)
s.Require().NoError(err, "should succeed without limits")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() {
user := s.mustCreateUser("exactlimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 正好达到限额应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().NoError(err, "should succeed at exact limit")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser) user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
...@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() { ...@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser) user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额 group := s.mustCreateGroup("g-concurrent")
sub := s.mustCreateSubscription(user.ID, group.ID, nil) sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10 const numGoroutines = 10
...@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { ...@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated") s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
} }
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() {
user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser)
dailyLimit := 5.0
group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
const numAttempts = 10
const incrementPerAttempt = 1.0
successCount := 0
for i := 0; i < numAttempts; i++ {
err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt)
if err == nil {
successCount++
}
}
// 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额)
s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)")
// 验证最终用量等于限额
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() { func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T()) baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background()) tx, err := baseClient.Tx(context.Background())
......
package repository package repository
import ( import (
"database/sql"
"errors"
entsql "entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
...@@ -54,4 +59,58 @@ var ProviderSet = wire.NewSet( ...@@ -54,4 +59,58 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient, NewOpenAIOAuthClient,
NewGeminiOAuthClient, NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient, NewGeminiCliCodeAssistClient,
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
) )
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖:config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存(如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖:config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
"os" "os"
"strings" "strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
......
...@@ -5,7 +5,7 @@ import ( ...@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
......
...@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard subscriptionType = SubscriptionTypeStandard
} }
// 限额字段:0 和 nil 都表示"无限制"
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
...@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: input.WeeklyLimitUSD, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: input.MonthlyLimitUSD, MonthlyLimitUSD: monthlyLimit,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil return group, nil
} }
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
return nil
}
return limit
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
...@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" { if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType group.SubscriptionType = input.SubscriptionType
} }
// 限额字段支持设置为nil(清除限额)或具体值 // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil { if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
} }
if input.WeeklyLimitUSD != nil { if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
} }
if input.MonthlyLimitUSD != nil { if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
} }
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
......
...@@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action // 构建上游 action
action := "generateContent" action := "generateContent"
if claudeReq.Stream { if claudeReq.Stream {
......
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"fmt"
"time" "time"
) )
...@@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { ...@@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
} }
// NeedsRefresh 检查账户是否需要刷新 // NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置 // Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool { func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) { if !r.CanRefresh(account) {
return false return false
...@@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati ...@@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati
if expiresAt == nil { if expiresAt == nil {
return false return false
} }
return time.Until(*expiresAt) < antigravityRefreshWindow timeUntilExpiry := time.Until(*expiresAt)
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
if needsRefresh {
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
}
return needsRefresh
} }
// Refresh 执行 token 刷新 // Refresh 执行 token 刷新
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
) )
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
// 错误定义 // 错误定义
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"strconv" "strconv"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
var ( var (
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -1061,6 +1062,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1061,6 +1062,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误) // 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
...@@ -1163,6 +1188,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -1163,6 +1188,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理) // 处理anthropic-beta header(OAuth账号需要特殊处理)
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil
...@@ -1215,6 +1247,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) ...@@ -1215,6 +1247,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.DefaultBetaHeader return claude.DefaultBetaHeader
} }
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
...@@ -1227,6 +1336,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ...@@ -1227,6 +1336,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch resp.StatusCode { switch resp.StatusCode {
case 400: case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body) c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401: case 401:
...@@ -1706,6 +1825,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1706,6 +1825,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态(429/529等) // 标记账号状态(429/529等)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应 // 返回简化的错误响应
errMsg := "Upstream request failed" errMsg := "Upstream request failed"
switch resp.StatusCode { switch resp.StatusCode {
...@@ -1786,6 +1917,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -1786,6 +1917,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil
......
...@@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any { ...@@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
"properties": map[string]any{}, "properties": map[string]any{},
} }
} }
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{ funcDecls = append(funcDecls, map[string]any{
"name": name, "name": name,
"description": desc, "description": desc,
"parameters": params, "parameters": cleanedParams,
}) })
} }
...@@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any { ...@@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
} }
} }
// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// 跳过不支持的字段
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// 递归清理嵌套对象
cleaned[key] = cleanToolSchema(value)
}
// 规范化 type 字段为大写
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any { func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any) out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
......
package service
import (
"testing"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {
name string
tools any
expectedLen int
description string
}{
{
name: "Standard tools",
tools: []any{
map[string]any{
"name": "get_weather",
"description": "Get weather info",
"input_schema": map[string]any{"type": "object"},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []any{
map[string]any{
"type": "custom",
"name": "mcp_tool",
"custom": map[string]any{
"description": "MCP tool description",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从custom字段读取",
},
{
name: "Mixed standard and custom tools",
tools: []any{
map[string]any{
"name": "standard_tool",
"description": "Standard",
"input_schema": map[string]any{"type": "object"},
},
map[string]any{
"type": "custom",
"name": "custom_tool",
"custom": map[string]any{
"description": "Custom",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "混合工具应该都能正确转换",
},
{
name: "Custom tool without custom field",
tools: []any{
map[string]any{
"type": "custom",
"name": "invalid_custom",
// 缺少 custom 字段
},
},
expectedLen: 0, // 应该被跳过
description: "缺少custom字段的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertClaudeToolsToGeminiTools(tt.tools)
if tt.expectedLen == 0 {
if result != nil {
t.Errorf("%s: expected nil result, got %v", tt.description, result)
}
return
}
if result == nil {
t.Fatalf("%s: expected non-nil result", tt.description)
}
if len(result) != 1 {
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
return
}
toolDecl, ok := result[0].(map[string]any)
if !ok {
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
}
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
if !ok {
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
}
toolsArr, _ := tt.tools.([]any)
expectedFuncCount := 0
for _, tool := range toolsArr {
toolMap, _ := tool.(map[string]any)
if toolMap["name"] != "" {
// 检查是否为有效的custom工具
if toolMap["type"] == "custom" {
if toolMap["custom"] != nil {
expectedFuncCount++
}
} else {
expectedFuncCount++
}
}
}
if len(funcDecls) != expectedFuncCount {
t.Errorf("%s: expected %d function declarations, got %d",
tt.description, expectedFuncCount, len(funcDecls))
}
})
}
}
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
...@@ -163,6 +164,45 @@ type GeminiTokenInfo struct { ...@@ -163,6 +164,45 @@ type GeminiTokenInfo struct {
Scope string `json:"scope,omitempty"` Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
}
// validateTierID validates tier_id format and length
func validateTierID(tierID string) error {
if tierID == "" {
return nil // Empty is allowed
}
if len(tierID) > 64 {
return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
}
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) {
return fmt.Errorf("tier_id contains invalid characters")
}
return nil
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
tierID := "LEGACY"
// First pass: look for default tier
for _, tier := range allowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
// Second pass: if still LEGACY, take first non-empty tier
if tierID == "LEGACY" {
for _, tier := range allowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
return tierID
} }
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
...@@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
projectID := sessionProjectID projectID := sessionProjectID
var tierID string
// 对于 code_assist 模式,project_id 是必需的 // 对于 code_assist 模式,project_id 是必需的
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
if oauthType == "code_assist" { if oauthType == "code_assist" {
if projectID == "" { if projectID == "" {
var err error var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil { if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id // 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
...@@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
Scope: tokenResp.Scope, Scope: tokenResp.Scope,
ProjectID: projectID, ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType, OAuthType: oauthType,
}, nil }, nil
} }
...@@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
// For Code Assist, project_id is required. Auto-detect if missing. // For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh. // For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
} }
...@@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
return nil, fmt.Errorf("failed to auto-detect project_id: empty result") return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
} }
tokenInfo.ProjectID = projectID tokenInfo.ProjectID = projectID
tokenInfo.TierID = tierID
} }
return tokenInfo, nil return tokenInfo, nil
...@@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) ...@@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if tokenInfo.ProjectID != "" { if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID creds["project_id"] = tokenInfo.ProjectID
} }
if tokenInfo.TierID != "" {
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
}
// Silently skip invalid tier_id (don't block account creation)
}
if tokenInfo.OAuthType != "" { if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType creds["oauth_type"] = tokenInfo.OAuthType
} }
...@@ -398,35 +448,27 @@ func (s *GeminiOAuthService) Stop() { ...@@ -398,35 +448,27 @@ func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop() s.sessionStore.Stop()
} }
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
if s.codeAssist == nil { if s.codeAssist == nil {
return "", errors.New("code assist client not configured") return "", "", errors.New("code assist client not configured")
} }
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. // Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY" tierID := "LEGACY"
if loadResp != nil { if loadResp != nil {
for _, tier := range loadResp.AllowedTiers { tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
} }
// If LoadCodeAssist returned a project, use it
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
} }
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
// (tierID already extracted above, reuse it)
req := &geminicli.OnboardUserRequest{ req := &geminicli.OnboardUserRequest{
TierID: tierID, TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{ Metadata: geminicli.LoadCodeAssistMetadata{
...@@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr ...@@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
return "", err return "", "", err
} }
if resp.Done { if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) { switch v := resp.Response.CloudAICompanionProject.(type) {
case string: case string:
return strings.TrimSpace(v), nil return strings.TrimSpace(v), tierID, nil
case map[string]any: case map[string]any:
if id, ok := v["id"].(string); ok { if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil return strings.TrimSpace(id), tierID, nil
} }
} }
} }
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
return "", errors.New("onboardUser completed but no project_id returned") return "", "", errors.New("onboardUser completed but no project_id returned")
} }
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
if loadErr != nil { if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
} }
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
} }
type googleCloudProject struct { type googleCloudProject struct {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment