Unverified Commit 9398ea7a authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1340 from DaydreamCoding/fix/privacy-and-system-prompt

fix(openai): OpenAI 隐私模式全场景覆盖 & 修复转发路径 system prompt 丢失
parents 29dce1a5 c729ee42
...@@ -539,6 +539,8 @@ func (h *AccountHandler) Create(c *gin.Context) { ...@@ -539,6 +539,8 @@ func (h *AccountHandler) Create(c *gin.Context) {
} }
// Antigravity OAuth: 新账号直接设置隐私 // Antigravity OAuth: 新账号直接设置隐私
h.adminService.ForceAntigravityPrivacy(ctx, account) h.adminService.ForceAntigravityPrivacy(ctx, account)
// OpenAI OAuth: 新账号直接设置隐私
h.adminService.ForceOpenAIPrivacy(ctx, account)
return h.buildAccountResponseWithRuntime(ctx, account), nil return h.buildAccountResponseWithRuntime(ctx, account), nil
}) })
if err != nil { if err != nil {
...@@ -785,6 +787,8 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv ...@@ -785,6 +787,8 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
if account.IsOpenAI() { if account.IsOpenAI() {
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account) tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
// 刷新失败但 access_token 可能仍有效,尝试设置隐私
h.adminService.EnsureOpenAIPrivacy(ctx, account)
return nil, "", err return nil, "", err
} }
...@@ -1159,8 +1163,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { ...@@ -1159,8 +1163,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
success := 0 success := 0
failed := 0 failed := 0
results := make([]gin.H, 0, len(req.Accounts)) results := make([]gin.H, 0, len(req.Accounts))
// 收集需要异步设置隐私的 Antigravity OAuth 账号 // 收集需要异步设置隐私的 OAuth 账号
var privacyAccounts []*service.Account var antigravityPrivacyAccounts []*service.Account
var openaiPrivacyAccounts []*service.Account
for _, item := range req.Accounts { for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 { if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
...@@ -1203,9 +1208,14 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { ...@@ -1203,9 +1208,14 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
}) })
continue continue
} }
// 收集 Antigravity OAuth 账号,稍后异步设置隐私 // 收集需要异步设置隐私的 OAuth 账号
if account.Platform == service.PlatformAntigravity && account.Type == service.AccountTypeOAuth { if account.Type == service.AccountTypeOAuth {
privacyAccounts = append(privacyAccounts, account) switch account.Platform {
case service.PlatformAntigravity:
antigravityPrivacyAccounts = append(antigravityPrivacyAccounts, account)
case service.PlatformOpenAI:
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account)
}
} }
success++ success++
results = append(results, gin.H{ results = append(results, gin.H{
...@@ -1215,9 +1225,10 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { ...@@ -1215,9 +1225,10 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
}) })
} }
// 异步设置 Antigravity 隐私,避免批量创建时阻塞请求 // 异步设置隐私,避免批量创建时阻塞请求
if len(privacyAccounts) > 0 {
adminSvc := h.adminService adminSvc := h.adminService
if len(antigravityPrivacyAccounts) > 0 {
accounts := antigravityPrivacyAccounts
go func() { go func() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
...@@ -1225,11 +1236,25 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { ...@@ -1225,11 +1236,25 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
} }
}() }()
bgCtx := context.Background() bgCtx := context.Background()
for _, acc := range privacyAccounts { for _, acc := range accounts {
adminSvc.ForceAntigravityPrivacy(bgCtx, acc) adminSvc.ForceAntigravityPrivacy(bgCtx, acc)
} }
}() }()
} }
if len(openaiPrivacyAccounts) > 0 {
accounts := openaiPrivacyAccounts
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("batch_create_openai_privacy_panic", "recover", r)
}
}()
bgCtx := context.Background()
for _, acc := range accounts {
adminSvc.ForceOpenAIPrivacy(bgCtx, acc)
}
}()
}
return gin.H{ return gin.H{
"success": success, "success": success,
...@@ -1896,7 +1921,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { ...@@ -1896,7 +1921,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models) response.Success(c, models)
} }
// SetPrivacy handles setting privacy for a single Antigravity OAuth account // SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
// POST /api/v1/admin/accounts/:id/set-privacy // POST /api/v1/admin/accounts/:id/set-privacy
func (h *AccountHandler) SetPrivacy(c *gin.Context) { func (h *AccountHandler) SetPrivacy(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
...@@ -1909,11 +1934,20 @@ func (h *AccountHandler) SetPrivacy(c *gin.Context) { ...@@ -1909,11 +1934,20 @@ func (h *AccountHandler) SetPrivacy(c *gin.Context) {
response.NotFound(c, "Account not found") response.NotFound(c, "Account not found")
return return
} }
if account.Platform != service.PlatformAntigravity || account.Type != service.AccountTypeOAuth { if account.Type != service.AccountTypeOAuth {
response.BadRequest(c, "Only Antigravity OAuth accounts support privacy setting") response.BadRequest(c, "Only OAuth accounts support privacy setting")
return
}
var mode string
switch account.Platform {
case service.PlatformOpenAI:
mode = h.adminService.ForceOpenAIPrivacy(c.Request.Context(), account)
case service.PlatformAntigravity:
mode = h.adminService.ForceAntigravityPrivacy(c.Request.Context(), account)
default:
response.BadRequest(c, "Only OpenAI and Antigravity OAuth accounts support privacy setting")
return return
} }
mode := h.adminService.ForceAntigravityPrivacy(c.Request.Context(), account)
if mode == "" { if mode == "" {
response.BadRequest(c, "Cannot set privacy: missing access_token") response.BadRequest(c, "Cannot set privacy: missing access_token")
return return
......
...@@ -449,6 +449,10 @@ func (s *stubAdminService) EnsureAntigravityPrivacy(ctx context.Context, account ...@@ -449,6 +449,10 @@ func (s *stubAdminService) EnsureAntigravityPrivacy(ctx context.Context, account
return "" return ""
} }
func (s *stubAdminService) ForceOpenAIPrivacy(ctx context.Context, account *service.Account) string {
return ""
}
func (s *stubAdminService) ForceAntigravityPrivacy(ctx context.Context, account *service.Account) string { func (s *stubAdminService) ForceAntigravityPrivacy(ctx context.Context, account *service.Account) string {
return "" return ""
} }
......
...@@ -67,6 +67,8 @@ type AdminService interface { ...@@ -67,6 +67,8 @@ type AdminService interface {
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
// EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号 privacy_mode,未设置则调用 setUserSettings 并持久化。 // EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号 privacy_mode,未设置则调用 setUserSettings 并持久化。
EnsureAntigravityPrivacy(ctx context.Context, account *Account) string EnsureAntigravityPrivacy(ctx context.Context, account *Account) string
// ForceOpenAIPrivacy 强制重新设置 OpenAI OAuth 账号隐私,无论当前状态。
ForceOpenAIPrivacy(ctx context.Context, account *Account) string
// ForceAntigravityPrivacy 强制重新设置 Antigravity OAuth 账号隐私,无论当前状态。 // ForceAntigravityPrivacy 强制重新设置 Antigravity OAuth 账号隐私,无论当前状态。
ForceAntigravityPrivacy(ctx context.Context, account *Account) string ForceAntigravityPrivacy(ctx context.Context, account *Account) string
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
...@@ -2664,6 +2666,43 @@ func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Acc ...@@ -2664,6 +2666,43 @@ func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Acc
return mode return mode
} }
// ForceOpenAIPrivacy 强制重新设置 OpenAI OAuth 账号隐私,无论当前状态。
func (s *adminServiceImpl) ForceOpenAIPrivacy(ctx context.Context, account *Account) string {
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return ""
}
if s.privacyClientFactory == nil {
return ""
}
token, _ := account.Credentials["access_token"].(string)
if token == "" {
return ""
}
var proxyURL string
if account.ProxyID != nil {
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
proxyURL = p.URL()
}
}
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
if mode == "" {
return ""
}
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil {
logger.LegacyPrintf("service.admin", "force_update_openai_privacy_mode_failed: account_id=%d err=%v", account.ID, err)
return mode
}
if account.Extra == nil {
account.Extra = make(map[string]any)
}
account.Extra["privacy_mode"] = mode
return mode
}
// EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号隐私状态。 // EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号隐私状态。
// 如果 Extra["privacy_mode"] 已存在(无论成功或失败),直接跳过。 // 如果 Extra["privacy_mode"] 已存在(无论成功或失败),直接跳过。
// 仅对从未设置过隐私的账号执行 setUserSettings + fetchUserInfo 流程。 // 仅对从未设置过隐私的账号执行 setUserSettings + fetchUserInfo 流程。
......
...@@ -124,6 +124,27 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) { ...@@ -124,6 +124,27 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}, },
want: false, want: false,
}, },
// json.RawMessage cases (conversion path: ForwardAsResponses / ForwardAsChatCompletions)
{
name: "json.RawMessage string with Claude Code prompt",
system: json.RawMessage(`"` + claudeCodeSystemPrompt + `"`),
want: true,
},
{
name: "json.RawMessage string without Claude Code prompt",
system: json.RawMessage(`"You are a helpful assistant"`),
want: false,
},
{
name: "json.RawMessage nil (empty)",
system: json.RawMessage(nil),
want: false,
},
{
name: "json.RawMessage empty string",
system: json.RawMessage(`""`),
want: false,
},
} }
for _, tt := range tests { for _, tt := range tests {
...@@ -202,6 +223,29 @@ func TestInjectClaudeCodePrompt(t *testing.T) { ...@@ -202,6 +223,29 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
wantSystemLen: 1, wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
}, },
// json.RawMessage cases (conversion path: ForwardAsResponses / ForwardAsChatCompletions)
{
name: "json.RawMessage string system",
body: `{"model":"claude-3","system":"Custom prompt"}`,
system: json.RawMessage(`"Custom prompt"`),
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: claudePrefix + "\n\nCustom prompt",
},
{
name: "json.RawMessage nil system",
body: `{"model":"claude-3"}`,
system: json.RawMessage(nil),
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
{
name: "json.RawMessage Claude Code prompt (should not duplicate)",
body: `{"model":"claude-3","system":"` + claudeCodeSystemPrompt + `"}`,
system: json.RawMessage(`"` + claudeCodeSystemPrompt + `"`),
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
} }
for _, tt := range tests { for _, tt := range tests {
......
...@@ -3749,9 +3749,28 @@ func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequ ...@@ -3749,9 +3749,28 @@ func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequ
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
} }
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
func normalizeSystemParam(system any) any {
raw, ok := system.(json.RawMessage)
if !ok {
return system
}
if len(raw) == 0 {
return nil
}
var parsed any
if err := json.Unmarshal(raw, &parsed); err != nil {
return nil
}
return parsed
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 // systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) // 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool { func systemIncludesClaudeCodePrompt(system any) bool {
system = normalizeSystemParam(system)
switch v := system.(type) { switch v := system.(type) {
case string: case string:
return hasClaudeCodePrefix(v) return hasClaudeCodePrefix(v)
...@@ -3780,6 +3799,7 @@ func hasClaudeCodePrefix(text string) bool { ...@@ -3780,6 +3799,7 @@ func hasClaudeCodePrefix(text string) bool {
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式 // 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte { func injectClaudeCodePrompt(body []byte, system any) []byte {
system = normalizeSystemParam(system)
claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err) logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err)
......
...@@ -300,6 +300,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -300,6 +300,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
"error", setErr, "error", setErr,
) )
} }
// 刷新失败但 access_token 可能仍有效,尝试设置隐私
s.ensureOpenAIPrivacy(ctx, account)
return err return err
} }
...@@ -327,6 +329,9 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -327,6 +329,9 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
"error", lastErr, "error", lastErr,
) )
// 刷新失败但 access_token 可能仍有效,尝试设置隐私
s.ensureOpenAIPrivacy(ctx, account)
// 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试) // 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试)
until := time.Now().Add(tokenRefreshTempUnschedDuration) until := time.Now().Add(tokenRefreshTempUnschedDuration)
reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr) reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr)
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
{{ t('admin.accounts.refreshToken') }} {{ t('admin.accounts.refreshToken') }}
</button> </button>
</template> </template>
<button v-if="isAntigravityOAuth" @click="$emit('set-privacy', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-emerald-600 hover:bg-gray-100 dark:hover:bg-dark-700"> <button v-if="supportsPrivacy" @click="$emit('set-privacy', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-emerald-600 hover:bg-gray-100 dark:hover:bg-dark-700">
<Icon name="shield" size="sm" /> <Icon name="shield" size="sm" />
{{ t('admin.accounts.setPrivacy') }} {{ t('admin.accounts.setPrivacy') }}
</button> </button>
...@@ -80,6 +80,8 @@ const hasRecoverableState = computed(() => { ...@@ -80,6 +80,8 @@ const hasRecoverableState = computed(() => {
return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value) return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value)
}) })
const isAntigravityOAuth = computed(() => props.account?.platform === 'antigravity' && props.account?.type === 'oauth') const isAntigravityOAuth = computed(() => props.account?.platform === 'antigravity' && props.account?.type === 'oauth')
const isOpenAIOAuth = computed(() => props.account?.platform === 'openai' && props.account?.type === 'oauth')
const supportsPrivacy = computed(() => isAntigravityOAuth.value || isOpenAIOAuth.value)
const hasQuotaLimit = computed(() => { const hasQuotaLimit = computed(() => {
return (props.account?.type === 'apikey' || props.account?.type === 'bedrock') && ( return (props.account?.type === 'apikey' || props.account?.type === 'bedrock') && (
(props.account?.quota_limit ?? 0) > 0 || (props.account?.quota_limit ?? 0) > 0 ||
......
...@@ -1262,7 +1262,7 @@ const handleSetPrivacy = async (a: Account) => { ...@@ -1262,7 +1262,7 @@ const handleSetPrivacy = async (a: Account) => {
appStore.showSuccess(t('common.success')) appStore.showSuccess(t('common.success'))
} catch (error: any) { } catch (error: any) {
console.error('Failed to set privacy:', error) console.error('Failed to set privacy:', error)
appStore.showError(error?.response?.data?.message || t('admin.accounts.privacyAntigravityFailed')) appStore.showError(error?.response?.data?.message || t('admin.accounts.privacyFailed'))
} }
} }
const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true } const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment