Unverified Commit 0507852a authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1437 from DaydreamCoding/fix/openai-oauth-improvements

feat(openai): OpenAI OAuth 账号管理增强:订阅状态、隐私设置、Token 刷新修复
parents 7b6ff135 e8efaa4c
...@@ -551,6 +551,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -551,6 +551,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
account.RateLimitResetAt = resetAt account.RateLimitResetAt = resetAt
} }
} }
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
} }
......
...@@ -1642,16 +1642,29 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -1642,16 +1642,29 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} }
} }
// Antigravity OAuth 账号:创建后异步设置隐私 // OAuth 账号:创建后异步设置隐私。
if account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth { // 使用 Ensure(幂等)而非 Force:新建账号 Extra 为空时效果相同,但更安全。
go func() { if account.Type == AccountTypeOAuth {
defer func() { switch account.Platform {
if r := recover(); r != nil { case PlatformOpenAI:
slog.Error("create_account_antigravity_privacy_panic", "account_id", account.ID, "recover", r) go func() {
} defer func() {
if r := recover(); r != nil {
slog.Error("create_account_openai_privacy_panic", "account_id", account.ID, "recover", r)
}
}()
s.EnsureOpenAIPrivacy(context.Background(), account)
}() }()
s.EnsureAntigravityPrivacy(context.Background(), account) case PlatformAntigravity:
}() go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("create_account_antigravity_privacy_panic", "account_id", account.ID, "recover", r)
}
}()
s.EnsureAntigravityPrivacy(context.Background(), account)
}()
}
} }
return account, nil return account, nil
......
...@@ -127,18 +127,19 @@ type OpenAIExchangeCodeInput struct { ...@@ -127,18 +127,19 @@ type OpenAIExchangeCodeInput struct {
// OpenAITokenInfo represents the token information for OpenAI // OpenAITokenInfo represents the token information for OpenAI
type OpenAITokenInfo struct { type OpenAITokenInfo struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"` IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"` ExpiresAt int64 `json:"expires_at"`
ClientID string `json:"client_id,omitempty"` ClientID string `json:"client_id,omitempty"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
OrganizationID string `json:"organization_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"`
PlanType string `json:"plan_type,omitempty"` PlanType string `json:"plan_type,omitempty"`
PrivacyMode string `json:"privacy_mode,omitempty"` SubscriptionExpiresAt string `json:"subscription_expires_at,omitempty"`
PrivacyMode string `json:"privacy_mode,omitempty"`
} }
// ExchangeCode exchanges authorization code for tokens // ExchangeCode exchanges authorization code for tokens
...@@ -214,6 +215,8 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch ...@@ -214,6 +215,8 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
tokenInfo.PlanType = userInfo.PlanType tokenInfo.PlanType = userInfo.PlanType
} }
s.enrichTokenInfo(ctx, tokenInfo, proxyURL)
return tokenInfo, nil return tokenInfo, nil
} }
...@@ -259,31 +262,40 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre ...@@ -259,31 +262,40 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
tokenInfo.PlanType = userInfo.PlanType tokenInfo.PlanType = userInfo.PlanType
} }
// id_token 中缺少 plan_type 时(如 Mobile RT),尝试通过 ChatGPT backend-api 补全 s.enrichTokenInfo(ctx, tokenInfo, proxyURL)
if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
// 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号 return tokenInfo, nil
orgID := tokenInfo.OrganizationID }
if orgID == "" {
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil { // enrichTokenInfo 通过 ChatGPT backend-api 补全 tokenInfo 并设置隐私(best-effort)。
orgID = atClaims.OpenAIAuth.POID // 从 accounts/check 获取最新 plan_type、subscription_expires_at、email,
} // 然后尝试关闭训练数据共享。适用于所有获取/刷新 token 的路径。
func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *OpenAITokenInfo, proxyURL string) {
if tokenInfo.AccessToken == "" || s.privacyClientFactory == nil {
return
}
// 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号
orgID := tokenInfo.OrganizationID
if orgID == "" {
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil {
orgID = atClaims.OpenAIAuth.POID
} }
if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil { }
if tokenInfo.PlanType == "" && info.PlanType != "" { if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil {
tokenInfo.PlanType = info.PlanType if info.PlanType != "" {
} tokenInfo.PlanType = info.PlanType
if tokenInfo.Email == "" && info.Email != "" { }
tokenInfo.Email = info.Email if info.SubscriptionExpiresAt != "" {
} tokenInfo.SubscriptionExpiresAt = info.SubscriptionExpiresAt
}
if tokenInfo.Email == "" && info.Email != "" {
tokenInfo.Email = info.Email
} }
} }
// 尝试设置隐私(关闭训练数据共享),best-effort // 尝试设置隐私(关闭训练数据共享),best-effort
if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil { tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}
return tokenInfo, nil
} }
// ExchangeSoraSessionToken exchanges Sora session_token to access_token. // ExchangeSoraSessionToken exchanges Sora session_token to access_token.
...@@ -567,6 +579,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) ...@@ -567,6 +579,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
if tokenInfo.PlanType != "" { if tokenInfo.PlanType != "" {
creds["plan_type"] = tokenInfo.PlanType creds["plan_type"] = tokenInfo.PlanType
} }
if tokenInfo.SubscriptionExpiresAt != "" {
creds["subscription_expires_at"] = tokenInfo.SubscriptionExpiresAt
}
if strings.TrimSpace(tokenInfo.ClientID) != "" { if strings.TrimSpace(tokenInfo.ClientID) != "" {
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
} }
......
...@@ -56,6 +56,10 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto ...@@ -56,6 +56,10 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto
SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Origin", "https://chatgpt.com"). SetHeader("Origin", "https://chatgpt.com").
SetHeader("Referer", "https://chatgpt.com/"). SetHeader("Referer", "https://chatgpt.com/").
SetHeader("Accept", "application/json").
SetHeader("sec-fetch-mode", "cors").
SetHeader("sec-fetch-site", "same-origin").
SetHeader("sec-fetch-dest", "empty").
SetQueryParam("feature", "training_allowed"). SetQueryParam("feature", "training_allowed").
SetQueryParam("value", "false"). SetQueryParam("value", "false").
Patch(openAISettingsURL) Patch(openAISettingsURL)
...@@ -84,8 +88,9 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto ...@@ -84,8 +88,9 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto
// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息 // ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息
type ChatGPTAccountInfo struct { type ChatGPTAccountInfo struct {
PlanType string PlanType string
Email string Email string
SubscriptionExpiresAt string // entitlement.expires_at (RFC3339)
} }
const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27" const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
...@@ -138,14 +143,20 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac ...@@ -138,14 +143,20 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac
// 优先匹配 orgID 对应的账号(access_token JWT 中的 poid) // 优先匹配 orgID 对应的账号(access_token JWT 中的 poid)
if orgID != "" { if orgID != "" {
if matched := extractPlanFromAccount(accounts, orgID); matched != "" { if acctRaw, ok := accounts[orgID]; ok {
info.PlanType = matched if acct, ok := acctRaw.(map[string]any); ok {
fillAccountInfo(info, acct)
}
} }
} }
// 未匹配到时,遍历所有账号:优先 is_default,次选非 free // 未匹配到时,遍历所有账号:优先 is_default,次选非 free
if info.PlanType == "" { if info.PlanType == "" {
var defaultPlan, paidPlan, anyPlan string type candidate struct {
planType string
expiresAt string
}
var defaultC, paidC, anyC candidate
for _, acctRaw := range accounts { for _, acctRaw := range accounts {
acct, ok := acctRaw.(map[string]any) acct, ok := acctRaw.(map[string]any)
if !ok { if !ok {
...@@ -155,26 +166,27 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac ...@@ -155,26 +166,27 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac
if planType == "" { if planType == "" {
continue continue
} }
if anyPlan == "" { ea := extractEntitlementExpiresAt(acct)
anyPlan = planType if anyC.planType == "" {
anyC = candidate{planType, ea}
} }
if account, ok := acct["account"].(map[string]any); ok { if account, ok := acct["account"].(map[string]any); ok {
if isDefault, _ := account["is_default"].(bool); isDefault { if isDefault, _ := account["is_default"].(bool); isDefault {
defaultPlan = planType defaultC = candidate{planType, ea}
} }
} }
if !strings.EqualFold(planType, "free") && paidPlan == "" { if !strings.EqualFold(planType, "free") && paidC.planType == "" {
paidPlan = planType paidC = candidate{planType, ea}
} }
} }
// 优先级:default > 非 free > 任意 // 优先级:default > 非 free > 任意
switch { switch {
case defaultPlan != "": case defaultC.planType != "":
info.PlanType = defaultPlan info.PlanType, info.SubscriptionExpiresAt = defaultC.planType, defaultC.expiresAt
case paidPlan != "": case paidC.planType != "":
info.PlanType = paidPlan info.PlanType, info.SubscriptionExpiresAt = paidC.planType, paidC.expiresAt
default: default:
info.PlanType = anyPlan info.PlanType, info.SubscriptionExpiresAt = anyC.planType, anyC.expiresAt
} }
} }
...@@ -183,21 +195,14 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac ...@@ -183,21 +195,14 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac
return nil return nil
} }
slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID) slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "subscription_expires_at", info.SubscriptionExpiresAt, "org_id", orgID)
return info return info
} }
// extractPlanFromAccount 从 accounts map 中按 key(account_id)精确匹配并提取 plan_type // fillAccountInfo 从单个 account 对象中提取 plan_type 和 subscription_expires_at
func extractPlanFromAccount(accounts map[string]any, accountKey string) string { func fillAccountInfo(info *ChatGPTAccountInfo, acct map[string]any) {
acctRaw, ok := accounts[accountKey] info.PlanType = extractPlanType(acct)
if !ok { info.SubscriptionExpiresAt = extractEntitlementExpiresAt(acct)
return ""
}
acct, ok := acctRaw.(map[string]any)
if !ok {
return ""
}
return extractPlanType(acct)
} }
// extractPlanType 从单个 account 对象中提取 plan_type // extractPlanType 从单个 account 对象中提取 plan_type
...@@ -215,6 +220,17 @@ func extractPlanType(acct map[string]any) string { ...@@ -215,6 +220,17 @@ func extractPlanType(acct map[string]any) string {
return "" return ""
} }
// extractEntitlementExpiresAt 从 entitlement 中提取 expires_at。
// 预期为 RFC3339 字符串格式,如 "2026-05-02T20:32:12+00:00"。
func extractEntitlementExpiresAt(acct map[string]any) string {
entitlement, ok := acct["entitlement"].(map[string]any)
if !ok {
return ""
}
ea, _ := entitlement["expires_at"].(string)
return ea
}
func truncate(s string, n int) string { func truncate(s string, n int) string {
if len(s) <= n { if len(s) <= n {
return s return s
......
...@@ -161,6 +161,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -161,6 +161,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable = true shouldDisable = true
break break
} }
// OpenAI: {"detail":"Unauthorized"} 表示 token 完全无效(非标准 OpenAI 错误格式),直接标记 error
if account.Platform == PlatformOpenAI && gjson.GetBytes(responseBody, "detail").String() == "Unauthorized" {
msg := "Unauthorized (401): account authentication failed permanently"
if upstreamMsg != "" {
msg = "Unauthorized (401): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
break
}
// OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。 // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。
// Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。 // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity { if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
......
...@@ -109,11 +109,11 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { ...@@ -109,11 +109,11 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
} }
// NeedsRefresh 检查token是否需要刷新 // NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内 // expires_at 缺失且处于限流状态时需要刷新,防止限流期间 token 静默过期
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetCredentialAsTime("expires_at") expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil { if expiresAt == nil {
return false return account.IsRateLimited()
} }
return time.Until(*expiresAt) < refreshWindow return time.Until(*expiresAt) < refreshWindow
......
...@@ -45,6 +45,10 @@ ...@@ -45,6 +45,10 @@
<span>{{ privacyBadge.label }}</span> <span>{{ privacyBadge.label }}</span>
</span> </span>
</div> </div>
<!-- Row 3: Subscription expiration (non-free paid accounts only) -->
<div v-if="expiresLabel" class="text-[10px] leading-tight text-gray-400 dark:text-gray-500 pl-0.5" :title="subscriptionExpiresAt">
{{ expiresLabel }}
</div>
</div> </div>
</template> </template>
...@@ -62,6 +66,7 @@ interface Props { ...@@ -62,6 +66,7 @@ interface Props {
type: AccountType type: AccountType
planType?: string planType?: string
privacyMode?: string privacyMode?: string
subscriptionExpiresAt?: string
} }
const props = defineProps<Props>() const props = defineProps<Props>()
...@@ -148,6 +153,22 @@ const planBadgeClass = computed(() => { ...@@ -148,6 +153,22 @@ const planBadgeClass = computed(() => {
return typeClass.value return typeClass.value
}) })
// Subscription expiration label (non-free only)
const expiresLabel = computed(() => {
if (!props.subscriptionExpiresAt || !props.planType) return ''
if (props.planType.toLowerCase() === 'free') return ''
try {
const d = new Date(props.subscriptionExpiresAt)
if (isNaN(d.getTime())) return ''
const yyyy = d.getFullYear()
const mm = String(d.getMonth() + 1).padStart(2, '0')
const dd = String(d.getDate()).padStart(2, '0')
return `${t('admin.accounts.subscriptionExpires')} ${yyyy}-${mm}-${dd}`
} catch {
return ''
}
})
// Privacy badge — shows different states for OpenAI/Antigravity OAuth privacy setting // Privacy badge — shows different states for OpenAI/Antigravity OAuth privacy setting
const privacyBadge = computed(() => { const privacyBadge = computed(() => {
if (props.type !== 'oauth' || !props.privacyMode) return null if (props.type !== 'oauth' || !props.privacyMode) return null
......
...@@ -1988,6 +1988,7 @@ export default { ...@@ -1988,6 +1988,7 @@ export default {
privacyAntigravityFailed: 'Privacy setting failed', privacyAntigravityFailed: 'Privacy setting failed',
setPrivacy: 'Set Privacy', setPrivacy: 'Set Privacy',
subscriptionAbnormal: 'Abnormal', subscriptionAbnormal: 'Abnormal',
subscriptionExpires: 'Expires',
// Capacity status tooltips // Capacity status tooltips
capacity: { capacity: {
windowCost: { windowCost: {
......
...@@ -2026,6 +2026,7 @@ export default { ...@@ -2026,6 +2026,7 @@ export default {
privacyAntigravityFailed: '隐私设置失败', privacyAntigravityFailed: '隐私设置失败',
setPrivacy: '设置隐私', setPrivacy: '设置隐私',
subscriptionAbnormal: '异常', subscriptionAbnormal: '异常',
subscriptionExpires: '到期',
// 容量状态提示 // 容量状态提示
capacity: { capacity: {
windowCost: { windowCost: {
......
...@@ -182,7 +182,7 @@ ...@@ -182,7 +182,7 @@
</template> </template>
<template #cell-platform_type="{ row }"> <template #cell-platform_type="{ row }">
<div class="flex flex-wrap items-center gap-1"> <div class="flex flex-wrap items-center gap-1">
<PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" :privacy-mode="row.extra?.privacy_mode" /> <PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" :privacy-mode="row.extra?.privacy_mode" :subscription-expires-at="row.credentials?.subscription_expires_at" />
<span <span
v-if="getAntigravityTierLabel(row)" v-if="getAntigravityTierLabel(row)"
:class="['inline-block rounded px-1.5 py-0.5 text-[10px] font-medium', getAntigravityTierClass(row)]" :class="['inline-block rounded px-1.5 py-0.5 text-[10px] font-medium', getAntigravityTierClass(row)]"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment