"backend/vscode:/vscode.git/clone" did not exist on "58707f8a2a472b51b0301f4e50cfb52ed94e3f19"
Commit f6fd7c83 authored by QTom's avatar QTom
Browse files

feat(antigravity): 从 LoadCodeAssist 复用 TierInfo 提取 plan_type

复用已有 GetTier() 返回的 tier ID(free-tier / g1-pro-tier /
g1-ultra-tier),通过 TierIDToPlanType 映射为 Free / Pro / Ultra,
在 loadProjectIDWithRetry 中顺带提取并写入 credentials.plan_type;
前端增加 Abnormal 异常套餐红色标记。

Made-with: Cursor
parent c2965c0f
...@@ -78,7 +78,9 @@ type UserInfo struct { ...@@ -78,7 +78,9 @@ type UserInfo struct {
// LoadCodeAssistRequest loadCodeAssist 请求 // LoadCodeAssistRequest loadCodeAssist 请求
type LoadCodeAssistRequest struct { type LoadCodeAssistRequest struct {
Metadata struct { Metadata struct {
IDEType string `json:"ideType"` IDEType string `json:"ideType"`
IDEVersion string `json:"ideVersion"`
IDEName string `json:"ideName"`
} `json:"metadata"` } `json:"metadata"`
} }
...@@ -223,6 +225,23 @@ func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit { ...@@ -223,6 +225,23 @@ func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
return r.PaidTier.AvailableCredits return r.PaidTier.AvailableCredits
} }
// TierIDToPlanType 将 tier ID 映射为用户可见的套餐名。
func TierIDToPlanType(tierID string) string {
switch strings.ToLower(strings.TrimSpace(tierID)) {
case "free-tier":
return "Free"
case "g1-pro-tier":
return "Pro"
case "g1-ultra-tier":
return "Ultra"
default:
if tierID == "" {
return "Free"
}
return tierID
}
}
// Client Antigravity API 客户端 // Client Antigravity API 客户端
type Client struct { type Client struct {
httpClient *http.Client httpClient *http.Client
...@@ -421,6 +440,8 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo ...@@ -421,6 +440,8 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
reqBody := LoadCodeAssistRequest{} reqBody := LoadCodeAssistRequest{}
reqBody.Metadata.IDEType = "ANTIGRAVITY" reqBody.Metadata.IDEType = "ANTIGRAVITY"
reqBody.Metadata.IDEVersion = "1.20.6"
reqBody.Metadata.IDEName = "antigravity"
bodyBytes, err := json.Marshal(reqBody) bodyBytes, err := json.Marshal(reqBody)
if err != nil { if err != nil {
......
...@@ -250,6 +250,27 @@ func TestGetTier_两者都为nil(t *testing.T) { ...@@ -250,6 +250,27 @@ func TestGetTier_两者都为nil(t *testing.T) {
} }
} }
func TestTierIDToPlanType(t *testing.T) {
tests := []struct {
tierID string
want string
}{
{"free-tier", "Free"},
{"g1-pro-tier", "Pro"},
{"g1-ultra-tier", "Ultra"},
{"FREE-TIER", "Free"},
{"", "Free"},
{"unknown-tier", "unknown-tier"},
}
for _, tt := range tests {
t.Run(tt.tierID, func(t *testing.T) {
if got := TierIDToPlanType(tt.tierID); got != tt.want {
t.Errorf("TierIDToPlanType(%q) = %q, want %q", tt.tierID, got, tt.want)
}
})
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// NewClient // NewClient
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -800,6 +821,12 @@ type redirectRoundTripper struct { ...@@ -800,6 +821,12 @@ type redirectRoundTripper struct {
transport http.RoundTripper transport http.RoundTripper
} }
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
originalURL := req.URL.String() originalURL := req.URL.String()
for prefix, target := range rt.redirects { for prefix, target := range rt.redirects {
...@@ -1271,6 +1298,12 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { ...@@ -1271,6 +1298,12 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) {
if reqBody.Metadata.IDEType != "ANTIGRAVITY" { if reqBody.Metadata.IDEType != "ANTIGRAVITY" {
t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType) t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType)
} }
if strings.TrimSpace(reqBody.Metadata.IDEVersion) == "" {
t.Errorf("IDEVersion 不应为空")
}
if reqBody.Metadata.IDEName != "antigravity" {
t.Errorf("IDEName 不匹配: got %s, want antigravity", reqBody.Metadata.IDEName)
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
......
...@@ -89,7 +89,8 @@ type AntigravityTokenInfo struct { ...@@ -89,7 +89,8 @@ type AntigravityTokenInfo struct {
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id ProjectIDMissing bool `json:"-"`
PlanType string `json:"-"`
} }
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
...@@ -145,13 +146,17 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig ...@@ -145,13 +146,17 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email result.Email = userInfo.Email
} }
// 获取 project_id(部分账户类型可能没有),失败时重试 // 获取 project_id + plan_type(部分账户类型可能没有),失败时重试
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3) loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if loadErr != nil { if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr) fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
result.ProjectIDMissing = true result.ProjectIDMissing = true
} else { }
result.ProjectID = projectID if loadResult != nil {
result.ProjectID = loadResult.ProjectID
if loadResult.Subscription != nil {
result.PlanType = loadResult.Subscription.PlanType
}
} }
return result, nil return result, nil
...@@ -230,13 +235,17 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr ...@@ -230,13 +235,17 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
tokenInfo.Email = userInfo.Email tokenInfo.Email = userInfo.Email
} }
// 获取 project_id(容错,失败不阻塞) // 获取 project_id + plan_type(容错,失败不阻塞)
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil { if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr) fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
tokenInfo.ProjectIDMissing = true tokenInfo.ProjectIDMissing = true
} else { }
tokenInfo.ProjectID = projectID if loadResult != nil {
tokenInfo.ProjectID = loadResult.ProjectID
if loadResult.Subscription != nil {
tokenInfo.PlanType = loadResult.Subscription.PlanType
}
} }
return tokenInfo, nil return tokenInfo, nil
...@@ -288,33 +297,42 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou ...@@ -288,33 +297,42 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail tokenInfo.Email = existingEmail
} }
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试 // 每次刷新都调用 LoadCodeAssist 获取 project_id + plan_type,失败时重试
existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil { if loadErr != nil {
// LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID tokenInfo.ProjectID = existingProjectID
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if existingProjectID == "" { if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true tokenInfo.ProjectIDMissing = true
} }
} else { }
tokenInfo.ProjectID = projectID if loadResult != nil {
if loadResult.ProjectID != "" {
tokenInfo.ProjectID = loadResult.ProjectID
}
if loadResult.Subscription != nil {
tokenInfo.PlanType = loadResult.Subscription.PlanType
}
} }
return tokenInfo, nil return tokenInfo, nil
} }
// loadProjectIDWithRetry 带重试机制获取 project_id // loadCodeAssistResult 封装 loadProjectIDWithRetry 的返回结果,
// 返回 project_id 和错误,失败时会重试指定次数 // 同时携带从 LoadCodeAssist 响应中提取的 plan_type 信息。
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) { type loadCodeAssistResult struct {
ProjectID string
Subscription *AntigravitySubscriptionResult
}
// loadProjectIDWithRetry 带重试机制获取 project_id,同时从响应中提取 plan_type。
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (*loadCodeAssistResult, error) {
var lastErr error var lastErr error
var lastSubscription *AntigravitySubscriptionResult
for attempt := 0; attempt <= maxRetries; attempt++ { for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
// 指数退避:1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second { if backoff > 8*time.Second {
backoff = 8 * time.Second backoff = 8 * time.Second
...@@ -324,24 +342,34 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac ...@@ -324,24 +342,34 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
client, err := antigravity.NewClient(proxyURL) client, err := antigravity.NewClient(proxyURL)
if err != nil { if err != nil {
return "", fmt.Errorf("create antigravity client failed: %w", err) return nil, fmt.Errorf("create antigravity client failed: %w", err)
} }
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if loadResp != nil {
sub := NormalizeAntigravitySubscription(loadResp)
lastSubscription = &sub
}
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil return &loadCodeAssistResult{
ProjectID: loadResp.CloudAICompanionProject,
Subscription: lastSubscription,
}, nil
} }
if err == nil { if err == nil {
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" { if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
return projectID, nil return &loadCodeAssistResult{
ProjectID: projectID,
Subscription: lastSubscription,
}, nil
} else if onboardErr != nil { } else if onboardErr != nil {
lastErr = onboardErr lastErr = onboardErr
continue continue
} }
} }
// 记录错误
if err != nil { if err != nil {
lastErr = err lastErr = err
} else if loadResp == nil { } else if loadResp == nil {
...@@ -351,7 +379,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac ...@@ -351,7 +379,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
} }
} }
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) if lastSubscription != nil {
return &loadCodeAssistResult{Subscription: lastSubscription}, fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
return nil, fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
} }
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) { func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
...@@ -410,7 +441,11 @@ func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Ac ...@@ -410,7 +441,11 @@ func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Ac
proxyURL = proxy.URL() proxyURL = proxy.URL()
} }
} }
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3) result, err := s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
if result != nil {
return result.ProjectID, err
}
return "", err
} }
// BuildAccountCredentials 构建账户凭证 // BuildAccountCredentials 构建账户凭证
...@@ -431,6 +466,9 @@ func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *Antigravity ...@@ -431,6 +466,9 @@ func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *Antigravity
if tokenInfo.ProjectID != "" { if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID creds["project_id"] = tokenInfo.ProjectID
} }
if tokenInfo.PlanType != "" {
creds["plan_type"] = tokenInfo.PlanType
}
return creds return creds
} }
......
...@@ -16,3 +16,26 @@ func TestApplyAntigravityPrivacyMode_SetsInMemoryExtra(t *testing.T) { ...@@ -16,3 +16,26 @@ func TestApplyAntigravityPrivacyMode_SetsInMemoryExtra(t *testing.T) {
t.Fatalf("expected privacy_mode %q, got %v", AntigravityPrivacySet, got) t.Fatalf("expected privacy_mode %q, got %v", AntigravityPrivacySet, got)
} }
} }
func TestApplyAntigravityPrivacyMode_PreservedBySubscriptionResult(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
"existing": "value",
},
}
applyAntigravityPrivacyMode(account, AntigravityPrivacySet)
_, extra := applyAntigravitySubscriptionResult(account, AntigravitySubscriptionResult{
PlanType: "Pro",
})
if got := extra["privacy_mode"]; got != AntigravityPrivacySet {
t.Fatalf("expected subscription writeback to keep privacy_mode %q, got %v", AntigravityPrivacySet, got)
}
if got := extra["existing"]; got != "value" {
t.Fatalf("expected existing extra fields to be preserved, got %v", got)
}
}
package service
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
const antigravitySubscriptionAbnormal = "abnormal"
// AntigravitySubscriptionResult 表示订阅检测后的规范化结果。
type AntigravitySubscriptionResult struct {
PlanType string
SubscriptionStatus string
SubscriptionError string
}
// NormalizeAntigravitySubscription 从 LoadCodeAssistResponse 提取 plan_type + 异常状态。
// 使用 GetTier()(返回 tier ID)+ TierIDToPlanType 映射。
func NormalizeAntigravitySubscription(resp *antigravity.LoadCodeAssistResponse) AntigravitySubscriptionResult {
if resp == nil {
return AntigravitySubscriptionResult{PlanType: "Free"}
}
if len(resp.IneligibleTiers) > 0 {
result := AntigravitySubscriptionResult{
PlanType: "Abnormal",
SubscriptionStatus: antigravitySubscriptionAbnormal,
}
if resp.IneligibleTiers[0] != nil {
result.SubscriptionError = strings.TrimSpace(resp.IneligibleTiers[0].ReasonMessage)
}
return result
}
tierID := resp.GetTier()
return AntigravitySubscriptionResult{
PlanType: antigravity.TierIDToPlanType(tierID),
}
}
func applyAntigravitySubscriptionResult(account *Account, result AntigravitySubscriptionResult) (map[string]any, map[string]any) {
credentials := make(map[string]any)
for k, v := range account.Credentials {
credentials[k] = v
}
credentials["plan_type"] = result.PlanType
extra := make(map[string]any)
for k, v := range account.Extra {
extra[k] = v
}
if result.SubscriptionStatus != "" {
extra["subscription_status"] = result.SubscriptionStatus
} else {
delete(extra, "subscription_status")
}
if result.SubscriptionError != "" {
extra["subscription_error"] = result.SubscriptionError
} else {
delete(extra, "subscription_error")
}
return credentials, extra
}
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
</div> </div>
<!-- Row 2: Plan type + Privacy mode (only if either exists) --> <!-- Row 2: Plan type + Privacy mode (only if either exists) -->
<div v-if="planLabel || privacyBadge" class="inline-flex items-center overflow-hidden rounded-md"> <div v-if="planLabel || privacyBadge" class="inline-flex items-center overflow-hidden rounded-md">
<span v-if="planLabel" :class="['inline-flex items-center gap-1 px-1.5 py-1', typeClass]"> <span v-if="planLabel" :class="['inline-flex items-center gap-1 px-1.5 py-1', planBadgeClass]">
<span>{{ planLabel }}</span> <span>{{ planLabel }}</span>
</span> </span>
<span <span
...@@ -102,6 +102,8 @@ const planLabel = computed(() => { ...@@ -102,6 +102,8 @@ const planLabel = computed(() => {
return 'Pro' return 'Pro'
case 'free': case 'free':
return 'Free' return 'Free'
case 'abnormal':
return t('admin.accounts.subscriptionAbnormal')
default: default:
return props.planType return props.planType
} }
...@@ -139,6 +141,13 @@ const typeClass = computed(() => { ...@@ -139,6 +141,13 @@ const typeClass = computed(() => {
return 'bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400' return 'bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400'
}) })
const planBadgeClass = computed(() => {
if (props.planType && props.planType.toLowerCase() === 'abnormal') {
return 'bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400'
}
return typeClass.value
})
// Privacy badge — shows different states for OpenAI/Antigravity OAuth privacy setting // Privacy badge — shows different states for OpenAI/Antigravity OAuth privacy setting
const privacyBadge = computed(() => { const privacyBadge = computed(() => {
if (props.type !== 'oauth' || !props.privacyMode) return null if (props.type !== 'oauth' || !props.privacyMode) return null
......
...@@ -1987,6 +1987,7 @@ export default { ...@@ -1987,6 +1987,7 @@ export default {
privacyAntigravitySet: 'Telemetry and marketing emails disabled', privacyAntigravitySet: 'Telemetry and marketing emails disabled',
privacyAntigravityFailed: 'Privacy setting failed', privacyAntigravityFailed: 'Privacy setting failed',
setPrivacy: 'Set Privacy', setPrivacy: 'Set Privacy',
subscriptionAbnormal: 'Abnormal',
// Capacity status tooltips // Capacity status tooltips
capacity: { capacity: {
windowCost: { windowCost: {
......
...@@ -2025,6 +2025,7 @@ export default { ...@@ -2025,6 +2025,7 @@ export default {
privacyAntigravitySet: '已关闭遥测和营销邮件', privacyAntigravitySet: '已关闭遥测和营销邮件',
privacyAntigravityFailed: '隐私设置失败', privacyAntigravityFailed: '隐私设置失败',
setPrivacy: '设置隐私', setPrivacy: '设置隐私',
subscriptionAbnormal: '异常',
// 容量状态提示 // 容量状态提示
capacity: { capacity: {
windowCost: { windowCost: {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment