Unverified Commit fa68cbad authored by InCerryGit's avatar InCerryGit Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 995ef134 0f033930
...@@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct { ...@@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct {
Model string Model string
RequestPath string RequestPath string
Stream bool Stream bool
// InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint string
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint string
// RequestedModel is the client-requested model name before mapping.
RequestedModel string
// UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping.
UpstreamModel string
// RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2.
// Matches service.RequestType enum semantics from usage_log.go.
RequestType *int16
UserAgent string UserAgent string
ErrorPhase string ErrorPhase string
......
...@@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct { ...@@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct {
UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"`
// UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped).
// Helps debug 404/routing errors by showing which endpoint was targeted.
UpstreamURL string `json:"upstream_url,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed). // Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt. // Required for retrying a specific upstream attempt.
UpstreamRequestBody string `json:"upstream_request_body,omitempty"` UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
...@@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { ...@@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody) ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody) ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind) ev.Kind = strings.TrimSpace(ev.Kind)
ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL)
ev.Message = strings.TrimSpace(ev.Message) ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail) ev.Detail = strings.TrimSpace(ev.Detail)
if ev.Message != "" { if ev.Message != "" {
...@@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) { ...@@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
} }
return out, nil return out, nil
} }
// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment
// to avoid leaking sensitive query parameters (e.g. OAuth tokens).
func safeUpstreamURL(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return ""
}
if idx := strings.IndexByte(rawURL, '?'); idx >= 0 {
rawURL = rawURL[:idx]
}
if idx := strings.IndexByte(rawURL, '#'); idx >= 0 {
rawURL = rawURL[:idx]
}
return rawURL
}
...@@ -8,6 +8,27 @@ import ( ...@@ -8,6 +8,27 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestSafeUpstreamURL(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"},
{"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"},
{"strips both", "https://host/path?token=secret#x", "https://host/path"},
{"no query or fragment", "https://host/path", "https://host/path"},
{"empty string", "", ""},
{"whitespace only", " ", ""},
{"query before fragment", "https://h/p?a=1#f", "https://h/p"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, safeUpstreamURL(tt.input))
})
}
}
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
......
...@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil { if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else { } else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
......
...@@ -15,9 +15,11 @@ import ( ...@@ -15,9 +15,11 @@ import (
type rateLimitAccountRepoStub struct { type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini mockAccountRepoForGemini
setErrorCalls int setErrorCalls int
tempCalls int tempCalls int
lastErrorMsg string updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
} }
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
...@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id ...@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return nil return nil
} }
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCredentialsCalls++
r.lastCredentials = cloneCredentials(credentials)
return nil
}
type tokenCacheInvalidatorRecorder struct { type tokenCacheInvalidatorRecorder struct {
accounts []*Account accounts []*Account
err error err error
...@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin ...@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require.True(t, shouldDisable) require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls) require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
} }
...@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { ...@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts) require.Empty(t, invalidator.accounts)
} }
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}
...@@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic( ...@@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic(
func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]Account, *pagination.PaginationResult, error) { func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) { func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) {
......
...@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun ...@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
} }
if c.accountRepo != nil { if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() { if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
} }
} }
......
...@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
newCredentials, err = refresher.Refresh(ctx, account) newCredentials, err = refresher.Refresh(ctx, account)
if newCredentials != nil { if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli() newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil {
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr) return fmt.Errorf("failed to save credentials: %w", saveErr)
} }
} }
......
...@@ -14,19 +14,40 @@ import ( ...@@ -14,19 +14,40 @@ import (
type tokenRefreshAccountRepo struct { type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
updateCalls int updateCalls int
setErrorCalls int fullUpdateCalls int
clearTempCalls int updateCredentialsCalls int
lastAccount *Account setErrorCalls int
updateErr error clearTempCalls int
lastAccount *Account
updateErr error
} }
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++ r.updateCalls++
r.fullUpdateCalls++
r.lastAccount = account r.lastAccount = account
return r.updateErr return r.updateErr
} }
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCalls++
r.updateCredentialsCalls++
if r.updateErr != nil {
return r.updateErr
}
cloned := cloneCredentials(credentials)
if r.accountsByID != nil {
if acc, ok := r.accountsByID[id]; ok && acc != nil {
acc.Credentials = cloned
r.lastAccount = acc
return nil
}
}
r.lastAccount = &Account{ID: id, Credentials: cloned}
return nil
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++ r.setErrorCalls++
return nil return nil
...@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { ...@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.Equal(t, 1, invalidator.calls) require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token")) require.Equal(t, "new-token", account.GetCredential("access_token"))
} }
...@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { ...@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
} }
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
Credentials: map[string]any{
"access_token": "old-token",
},
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.NotNil(t, account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")} repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
...@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing ...@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除 require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
} }
......
-- Ops error logs: add endpoint, model mapping, and request_type fields
-- to match usage_logs observability coverage.
--
-- All columns are nullable with no default to preserve backward compatibility
-- with existing rows.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256),
ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256);
-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100),
ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);
-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS request_type SMALLINT;
COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.';
COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.';
COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).';
COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. NULL means no mapping applied.';
COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. Matches usage_logs.request_type semantics.';
...@@ -36,6 +36,7 @@ export async function list( ...@@ -36,6 +36,7 @@ export async function list(
status?: string status?: string
group?: string group?: string
search?: string search?: string
privacy_mode?: string
lite?: string lite?: string
}, },
options?: { options?: {
...@@ -68,6 +69,7 @@ export async function listWithEtag( ...@@ -68,6 +69,7 @@ export async function listWithEtag(
status?: string status?: string
group?: string group?: string
search?: string search?: string
privacy_mode?: string
lite?: string lite?: string
}, },
options?: { options?: {
...@@ -550,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string ...@@ -550,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
export async function refreshOpenAIToken( export async function refreshOpenAIToken(
refreshToken: string, refreshToken: string,
proxyId?: number | null, proxyId?: number | null,
endpoint: string = '/admin/openai/refresh-token' endpoint: string = '/admin/openai/refresh-token',
clientId?: string
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
const payload: { refresh_token: string; proxy_id?: number } = { const payload: { refresh_token: string; proxy_id?: number; client_id?: string } = {
refresh_token: refreshToken refresh_token: refreshToken
} }
if (proxyId) { if (proxyId) {
payload.proxy_id = proxyId payload.proxy_id = proxyId
} }
if (clientId) {
payload.client_id = clientId
}
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload) const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
return data return data
} }
......
...@@ -969,6 +969,13 @@ export interface OpsErrorLog { ...@@ -969,6 +969,13 @@ export interface OpsErrorLog {
client_ip?: string | null client_ip?: string | null
request_path?: string request_path?: string
stream?: boolean stream?: boolean
// Error observability context (endpoint + model mapping)
inbound_endpoint?: string
upstream_endpoint?: string
requested_model?: string
upstream_model?: string
request_type?: number | null
} }
export interface OpsErrorDetail extends OpsErrorLog { export interface OpsErrorDetail extends OpsErrorLog {
......
...@@ -599,6 +599,43 @@ ...@@ -599,6 +599,43 @@
</div> </div>
</div> </div>
<!-- OpenAI OAuth WS mode -->
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-openai-ws-mode-label"
class="input-label mb-0"
for="bulk-edit-openai-ws-mode-enabled"
>
{{ t('admin.accounts.openai.wsMode') }}
</label>
<input
v-model="enableOpenAIWSMode"
id="bulk-edit-openai-ws-mode-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-ws-mode"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-ws-mode"
:class="!enableOpenAIWSMode && 'pointer-events-none opacity-50'"
>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.wsModeDesc') }}
</p>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t(openAIWSModeConcurrencyHintKey) }}
</p>
<Select
v-model="openaiOAuthResponsesWebSocketV2Mode"
data-testid="bulk-edit-openai-ws-mode-select"
:options="openAIWSModeOptions"
aria-labelledby="bulk-edit-openai-ws-mode-label"
/>
</div>
</div>
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) --> <!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600"> <div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between"> <div class="mb-3 flex items-center justify-between">
...@@ -821,6 +858,13 @@ import { ...@@ -821,6 +858,13 @@ import {
buildModelMappingObject as buildModelMappingPayload, buildModelMappingObject as buildModelMappingPayload,
getPresetMappingsByPlatform getPresetMappingsByPlatform
} from '@/composables/useModelWhitelist' } from '@/composables/useModelWhitelist'
import {
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
resolveOpenAIWSModeConcurrencyHintKey
} from '@/utils/openaiWsMode'
import type { OpenAIWSMode } from '@/utils/openaiWsMode'
interface Props { interface Props {
show: boolean show: boolean
...@@ -843,6 +887,15 @@ const appStore = useAppStore() ...@@ -843,6 +887,15 @@ const appStore = useAppStore()
// Platform awareness // Platform awareness
const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1) const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
const allOpenAIOAuth = computed(() => {
return (
props.selectedPlatforms.length === 1 &&
props.selectedPlatforms[0] === 'openai' &&
props.selectedTypes.length > 0 &&
props.selectedTypes.every(t => t === 'oauth')
)
})
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示) // 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
const allAnthropicOAuthOrSetupToken = computed(() => { const allAnthropicOAuthOrSetupToken = computed(() => {
return ( return (
...@@ -886,6 +939,7 @@ const enablePriority = ref(false) ...@@ -886,6 +939,7 @@ const enablePriority = ref(false)
const enableRateMultiplier = ref(false) const enableRateMultiplier = ref(false)
const enableStatus = ref(false) const enableStatus = ref(false)
const enableGroups = ref(false) const enableGroups = ref(false)
const enableOpenAIWSMode = ref(false)
const enableRpmLimit = ref(false) const enableRpmLimit = ref(false)
// State - field values // State - field values
...@@ -907,6 +961,7 @@ const priority = ref(1) ...@@ -907,6 +961,7 @@ const priority = ref(1)
const rateMultiplier = ref(1) const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active') const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([]) const groupIds = ref<number[]>([])
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const rpmLimitEnabled = ref(false) const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref<number | null>(null) const bulkBaseRpm = ref<number | null>(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
...@@ -933,6 +988,13 @@ const statusOptions = computed(() => [ ...@@ -933,6 +988,13 @@ const statusOptions = computed(() => [
{ value: 'active', label: t('common.active') }, { value: 'active', label: t('common.active') },
{ value: 'inactive', label: t('common.inactive') } { value: 'inactive', label: t('common.inactive') }
]) ])
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
)
// Model mapping helpers // Model mapping helpers
const addModelMapping = () => { const addModelMapping = () => {
...@@ -1015,6 +1077,12 @@ const buildUpdatePayload = (): Record<string, unknown> | null => { ...@@ -1015,6 +1077,12 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
const updates: Record<string, unknown> = {} const updates: Record<string, unknown> = {}
const credentials: Record<string, unknown> = {} const credentials: Record<string, unknown> = {}
let credentialsChanged = false let credentialsChanged = false
const ensureExtra = (): Record<string, unknown> => {
if (!updates.extra) {
updates.extra = {}
}
return updates.extra as Record<string, unknown>
}
if (enableProxy.value) { if (enableProxy.value) {
// 后端期望 proxy_id: 0 表示清除代理,而不是 null // 后端期望 proxy_id: 0 表示清除代理,而不是 null
...@@ -1089,9 +1157,17 @@ const buildUpdatePayload = (): Record<string, unknown> | null => { ...@@ -1089,9 +1157,17 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
updates.credentials = credentials updates.credentials = credentials
} }
if (enableOpenAIWSMode.value) {
const extra = ensureExtra()
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
openaiOAuthResponsesWebSocketV2Mode.value
)
}
// RPM limit settings (写入 extra 字段) // RPM limit settings (写入 extra 字段)
if (enableRpmLimit.value) { if (enableRpmLimit.value) {
const extra: Record<string, unknown> = {} const extra = ensureExtra()
if (rpmLimitEnabled.value && bulkBaseRpm.value != null && bulkBaseRpm.value > 0) { if (rpmLimitEnabled.value && bulkBaseRpm.value != null && bulkBaseRpm.value > 0) {
extra.base_rpm = bulkBaseRpm.value extra.base_rpm = bulkBaseRpm.value
extra.rpm_strategy = bulkRpmStrategy.value extra.rpm_strategy = bulkRpmStrategy.value
...@@ -1111,8 +1187,7 @@ const buildUpdatePayload = (): Record<string, unknown> | null => { ...@@ -1111,8 +1187,7 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
// UMQ mode(独立于 RPM 保存) // UMQ mode(独立于 RPM 保存)
if (userMsgQueueMode.value !== null) { if (userMsgQueueMode.value !== null) {
if (!updates.extra) updates.extra = {} const umqExtra = ensureExtra()
const umqExtra = updates.extra as Record<string, unknown>
umqExtra.user_msg_queue_mode = userMsgQueueMode.value // '' = 清除账号级覆盖 umqExtra.user_msg_queue_mode = userMsgQueueMode.value // '' = 清除账号级覆盖
umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge) umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge)
} }
...@@ -1178,6 +1253,7 @@ const handleSubmit = async () => { ...@@ -1178,6 +1253,7 @@ const handleSubmit = async () => {
enableRateMultiplier.value || enableRateMultiplier.value ||
enableStatus.value || enableStatus.value ||
enableGroups.value || enableGroups.value ||
enableOpenAIWSMode.value ||
enableRpmLimit.value || enableRpmLimit.value ||
userMsgQueueMode.value !== null userMsgQueueMode.value !== null
...@@ -1269,6 +1345,7 @@ watch( ...@@ -1269,6 +1345,7 @@ watch(
enableRateMultiplier.value = false enableRateMultiplier.value = false
enableStatus.value = false enableStatus.value = false
enableGroups.value = false enableGroups.value = false
enableOpenAIWSMode.value = false
enableRpmLimit.value = false enableRpmLimit.value = false
// Reset all values // Reset all values
...@@ -1286,6 +1363,7 @@ watch( ...@@ -1286,6 +1363,7 @@ watch(
rateMultiplier.value = 1 rateMultiplier.value = 1
status.value = 'active' status.value = 'active'
groupIds.value = [] groupIds.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
rpmLimitEnabled.value = false rpmLimitEnabled.value = false
bulkBaseRpm.value = null bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered' bulkRpmStrategy.value = 'tiered'
......
...@@ -2504,6 +2504,7 @@ ...@@ -2504,6 +2504,7 @@
:allow-multiple="form.platform === 'anthropic'" :allow-multiple="form.platform === 'anthropic'"
:show-cookie-option="form.platform === 'anthropic'" :show-cookie-option="form.platform === 'anthropic'"
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'" :show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
:show-mobile-refresh-token-option="form.platform === 'openai'"
:show-session-token-option="form.platform === 'sora'" :show-session-token-option="form.platform === 'sora'"
:show-access-token-option="form.platform === 'sora'" :show-access-token-option="form.platform === 'sora'"
:platform="form.platform" :platform="form.platform"
...@@ -2511,6 +2512,7 @@ ...@@ -2511,6 +2512,7 @@
@generate-url="handleGenerateUrl" @generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth" @cookie-auth="handleCookieAuth"
@validate-refresh-token="handleValidateRefreshToken" @validate-refresh-token="handleValidateRefreshToken"
@validate-mobile-refresh-token="handleOpenAIValidateMobileRT"
@validate-session-token="handleValidateSessionToken" @validate-session-token="handleValidateSessionToken"
@import-access-token="handleImportAccessToken" @import-access-token="handleImportAccessToken"
/> />
...@@ -4360,11 +4362,14 @@ const handleOpenAIExchange = async (authCode: string) => { ...@@ -4360,11 +4362,14 @@ const handleOpenAIExchange = async (authCode: string) => {
} }
// OpenAI 手动 RT 批量验证和创建 // OpenAI 手动 RT 批量验证和创建
const handleOpenAIValidateRT = async (refreshTokenInput: string) => { // OpenAI Mobile RT 使用的 client_id(与后端 openai.SoraClientID 一致)
const OPENAI_MOBILE_RT_CLIENT_ID = 'app_LlGpXReQgckcGGUo2JrYvtJK'
// OpenAI/Sora RT 批量验证和创建(共享逻辑)
const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string) => {
const oauthClient = activeOpenAIOAuth.value const oauthClient = activeOpenAIOAuth.value
if (!refreshTokenInput.trim()) return if (!refreshTokenInput.trim()) return
// Parse multiple refresh tokens (one per line)
const refreshTokens = refreshTokenInput const refreshTokens = refreshTokenInput
.split('\n') .split('\n')
.map((rt) => rt.trim()) .map((rt) => rt.trim())
...@@ -4389,7 +4394,8 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { ...@@ -4389,7 +4394,8 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
try { try {
const tokenInfo = await oauthClient.validateRefreshToken( const tokenInfo = await oauthClient.validateRefreshToken(
refreshTokens[i], refreshTokens[i],
form.proxy_id form.proxy_id,
clientId
) )
if (!tokenInfo) { if (!tokenInfo) {
failedCount++ failedCount++
...@@ -4399,6 +4405,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { ...@@ -4399,6 +4405,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
} }
const credentials = oauthClient.buildCredentials(tokenInfo) const credentials = oauthClient.buildCredentials(tokenInfo)
if (clientId) {
credentials.client_id = clientId
}
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
const extra = buildOpenAIExtra(oauthExtra) const extra = buildOpenAIExtra(oauthExtra)
...@@ -4410,8 +4419,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { ...@@ -4410,8 +4419,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
} }
} }
// Generate account name with index for batch // Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'
const accountName = refreshTokens.length > 1 ? `${baseName} #${i + 1}` : baseName
let openaiAccountId: string | number | undefined let openaiAccountId: string | number | undefined
...@@ -4494,6 +4504,12 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { ...@@ -4494,6 +4504,12 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
} }
} }
// 手动输入 RT(Codex CLI client_id,默认)
const handleOpenAIValidateRT = (rt: string) => handleOpenAIBatchRT(rt)
// 手动输入 Mobile RT(SoraClientID)
const handleOpenAIValidateMobileRT = (rt: string) => handleOpenAIBatchRT(rt, OPENAI_MOBILE_RT_CLIENT_ID)
// Sora 手动 ST 批量验证和创建 // Sora 手动 ST 批量验证和创建
const handleSoraValidateST = async (sessionTokenInput: string) => { const handleSoraValidateST = async (sessionTokenInput: string) => {
const oauthClient = activeOpenAIOAuth.value const oauthClient = activeOpenAIOAuth.value
......
...@@ -48,6 +48,17 @@ ...@@ -48,6 +48,17 @@
t(getOAuthKey('refreshTokenAuth')) t(getOAuthKey('refreshTokenAuth'))
}}</span> }}</span>
</label> </label>
<label v-if="showMobileRefreshTokenOption" class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
type="radio"
value="mobile_refresh_token"
class="text-blue-600 focus:ring-blue-500"
/>
<span class="text-sm text-blue-900 dark:text-blue-200">{{
t('admin.accounts.oauth.openai.mobileRefreshTokenAuth', '手动输入 Mobile RT')
}}</span>
</label>
<label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2"> <label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2">
<input <input
v-model="inputMethod" v-model="inputMethod"
...@@ -73,8 +84,8 @@ ...@@ -73,8 +84,8 @@
</div> </div>
</div> </div>
<!-- Refresh Token Input (OpenAI / Antigravity) --> <!-- Refresh Token Input (OpenAI / Antigravity / Mobile RT) -->
<div v-if="inputMethod === 'refresh_token'" class="space-y-4"> <div v-if="inputMethod === 'refresh_token' || inputMethod === 'mobile_refresh_token'" class="space-y-4">
<div <div
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80" class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
> >
...@@ -759,6 +770,7 @@ interface Props { ...@@ -759,6 +770,7 @@ interface Props {
methodLabel?: string methodLabel?: string
showCookieOption?: boolean // Whether to show cookie auto-auth option showCookieOption?: boolean // Whether to show cookie auto-auth option
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only) showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
showMobileRefreshTokenOption?: boolean // Whether to show mobile refresh token option (OpenAI only)
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only) showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
showAccessTokenOption?: boolean // Whether to show access token input option (Sora only) showAccessTokenOption?: boolean // Whether to show access token input option (Sora only)
platform?: AccountPlatform // Platform type for different UI/text platform?: AccountPlatform // Platform type for different UI/text
...@@ -776,6 +788,7 @@ const props = withDefaults(defineProps<Props>(), { ...@@ -776,6 +788,7 @@ const props = withDefaults(defineProps<Props>(), {
methodLabel: 'Authorization Method', methodLabel: 'Authorization Method',
showCookieOption: true, showCookieOption: true,
showRefreshTokenOption: false, showRefreshTokenOption: false,
showMobileRefreshTokenOption: false,
showSessionTokenOption: false, showSessionTokenOption: false,
showAccessTokenOption: false, showAccessTokenOption: false,
platform: 'anthropic', platform: 'anthropic',
...@@ -787,6 +800,7 @@ const emit = defineEmits<{ ...@@ -787,6 +800,7 @@ const emit = defineEmits<{
'exchange-code': [code: string] 'exchange-code': [code: string]
'cookie-auth': [sessionKey: string] 'cookie-auth': [sessionKey: string]
'validate-refresh-token': [refreshToken: string] 'validate-refresh-token': [refreshToken: string]
'validate-mobile-refresh-token': [refreshToken: string]
'validate-session-token': [sessionToken: string] 'validate-session-token': [sessionToken: string]
'import-access-token': [accessToken: string] 'import-access-token': [accessToken: string]
'update:inputMethod': [method: AuthInputMethod] 'update:inputMethod': [method: AuthInputMethod]
...@@ -834,7 +848,7 @@ const oauthState = ref('') ...@@ -834,7 +848,7 @@ const oauthState = ref('')
const projectId = ref('') const projectId = ref('')
// Computed: show method selection when either cookie or refresh token option is enabled // Computed: show method selection when either cookie or refresh token option is enabled
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption)
// Clipboard // Clipboard
const { copied, copyToClipboard } = useClipboard() const { copied, copyToClipboard } = useClipboard()
...@@ -945,7 +959,11 @@ const handleCookieAuth = () => { ...@@ -945,7 +959,11 @@ const handleCookieAuth = () => {
const handleValidateRefreshToken = () => { const handleValidateRefreshToken = () => {
if (refreshTokenInput.value.trim()) { if (refreshTokenInput.value.trim()) {
emit('validate-refresh-token', refreshTokenInput.value.trim()) if (inputMethod.value === 'mobile_refresh_token') {
emit('validate-mobile-refresh-token', refreshTokenInput.value.trim())
} else {
emit('validate-refresh-token', refreshTokenInput.value.trim())
}
} }
} }
......
...@@ -50,7 +50,21 @@ function mountModal(extraProps: Record<string, unknown> = {}) { ...@@ -50,7 +50,21 @@ function mountModal(extraProps: Record<string, unknown> = {}) {
stubs: { stubs: {
BaseDialog: { template: '<div><slot /><slot name="footer" /></div>' }, BaseDialog: { template: '<div><slot /><slot name="footer" /></div>' },
ConfirmDialog: true, ConfirmDialog: true,
Select: true, Select: {
props: ['modelValue', 'options'],
emits: ['update:modelValue'],
template: `
<select
v-bind="$attrs"
:value="modelValue"
@change="$emit('update:modelValue', $event.target.value)"
>
<option v-for="option in options" :key="option.value" :value="option.value">
{{ option.label }}
</option>
</select>
`
},
ProxySelector: true, ProxySelector: true,
GroupSelector: true, GroupSelector: true,
Icon: true Icon: true
...@@ -115,4 +129,33 @@ describe('BulkEditAccountModal', () => { ...@@ -115,4 +129,33 @@ describe('BulkEditAccountModal', () => {
} }
}) })
}) })
it('OpenAI OAuth 批量编辑应提交 OAuth 专属 WS mode 字段', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['oauth']
})
await wrapper.get('#bulk-edit-openai-ws-mode-enabled').setValue(true)
await wrapper.get('[data-testid="bulk-edit-openai-ws-mode-select"]').setValue('passthrough')
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
extra: {
openai_oauth_responses_websockets_v2_mode: 'passthrough',
openai_oauth_responses_websockets_v2_enabled: true
}
})
})
it('OpenAI API Key 批量编辑不显示 WS mode 入口', () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['apikey']
})
expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
})
}) })
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" /> <Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" /> <Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" /> <Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
<Select :model-value="filters.privacy_mode" class="w-40" :options="privacyOpts" @update:model-value="updatePrivacyMode" @change="$emit('change')" />
<Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" /> <Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" />
</div> </div>
</template> </template>
...@@ -22,10 +23,18 @@ const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); co ...@@ -22,10 +23,18 @@ const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); co
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) } const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) } const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) } const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
const updatePrivacyMode = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, privacy_mode: value }) }
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) } const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }]) const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }]) const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }]) const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
const privacyOpts = computed(() => [
{ value: '', label: t('admin.accounts.allPrivacyModes') },
{ value: '__unset__', label: t('admin.accounts.privacyUnset') },
{ value: 'training_off', label: 'Privacy' },
{ value: 'training_set_cf_blocked', label: 'CF' },
{ value: 'training_set_failed', label: 'Fail' }
])
const gOpts = computed(() => [ const gOpts = computed(() => [
{ value: '', label: t('admin.accounts.allGroups') }, { value: '', label: t('admin.accounts.allGroups') },
{ value: 'ungrouped', label: t('admin.accounts.ungroupedGroup') }, { value: 'ungrouped', label: t('admin.accounts.ungroupedGroup') },
......
import { describe, expect, it, vi } from 'vitest'
import { mount } from '@vue/test-utils'
import AccountTableFilters from '../AccountTableFilters.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
describe('AccountTableFilters', () => {
it('renders privacy mode options and emits privacy_mode updates', async () => {
const wrapper = mount(AccountTableFilters, {
props: {
searchQuery: '',
filters: {
platform: '',
type: '',
status: '',
group: '',
privacy_mode: ''
},
groups: []
},
global: {
stubs: {
SearchInput: {
template: '<div />'
},
Select: {
props: ['modelValue', 'options'],
emits: ['update:modelValue', 'change'],
template: '<div class="select-stub" :data-options="JSON.stringify(options)" />'
}
}
}
})
const selects = wrapper.findAll('.select-stub')
expect(selects).toHaveLength(5)
const privacyOptions = JSON.parse(selects[3].attributes('data-options'))
expect(privacyOptions).toEqual([
{ value: '', label: 'admin.accounts.allPrivacyModes' },
{ value: '__unset__', label: 'admin.accounts.privacyUnset' },
{ value: 'training_off', label: 'Privacy' },
{ value: 'training_set_cf_blocked', label: 'CF' },
{ value: 'training_set_failed', label: 'Fail' }
])
})
})
...@@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app' ...@@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
export type AddMethod = 'oauth' | 'setup-token' export type AddMethod = 'oauth' | 'setup-token'
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' | 'access_token' export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'mobile_refresh_token' | 'session_token' | 'access_token'
export interface OAuthState { export interface OAuthState {
authUrl: string authUrl: string
......
...@@ -13,6 +13,8 @@ export interface OpenAITokenInfo { ...@@ -13,6 +13,8 @@ export interface OpenAITokenInfo {
scope?: string scope?: string
email?: string email?: string
name?: string name?: string
plan_type?: string
privacy_mode?: string
// OpenAI specific IDs (extracted from ID Token) // OpenAI specific IDs (extracted from ID Token)
chatgpt_account_id?: string chatgpt_account_id?: string
chatgpt_user_id?: string chatgpt_user_id?: string
...@@ -126,9 +128,11 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { ...@@ -126,9 +128,11 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
} }
// Validate refresh token and get full token info // Validate refresh token and get full token info
// clientId: 指定 OAuth client_id(用于第三方渠道获取的 RT,如 app_LlGpXReQgckcGGUo2JrYvtJK)
const validateRefreshToken = async ( const validateRefreshToken = async (
refreshToken: string, refreshToken: string,
proxyId?: number | null proxyId?: number | null,
clientId?: string
): Promise<OpenAITokenInfo | null> => { ): Promise<OpenAITokenInfo | null> => {
if (!refreshToken.trim()) { if (!refreshToken.trim()) {
error.value = 'Missing refresh token' error.value = 'Missing refresh token'
...@@ -143,11 +147,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { ...@@ -143,11 +147,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken( const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(
refreshToken.trim(), refreshToken.trim(),
proxyId, proxyId,
`${endpointPrefix}/refresh-token` `${endpointPrefix}/refresh-token`,
clientId
) )
return tokenInfo as OpenAITokenInfo return tokenInfo as OpenAITokenInfo
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to validate refresh token' error.value = err.response?.data?.detail || err.message || 'Failed to validate refresh token'
appStore.showError(error.value) appStore.showError(error.value)
return null return null
} finally { } finally {
...@@ -182,22 +187,23 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { ...@@ -182,22 +187,23 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
} }
} }
// Build credentials for OpenAI OAuth account // Build credentials for OpenAI OAuth account (aligned with backend BuildAccountCredentials)
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => { const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
const creds: Record<string, unknown> = { const creds: Record<string, unknown> = {
access_token: tokenInfo.access_token, access_token: tokenInfo.access_token,
refresh_token: tokenInfo.refresh_token, expires_at: tokenInfo.expires_at
token_type: tokenInfo.token_type,
expires_in: tokenInfo.expires_in,
expires_at: tokenInfo.expires_at,
scope: tokenInfo.scope
} }
if (tokenInfo.client_id) { // 仅在返回了新的 refresh_token 时才写入,防止用空值覆盖已有令牌
creds.client_id = tokenInfo.client_id if (tokenInfo.refresh_token) {
creds.refresh_token = tokenInfo.refresh_token
}
if (tokenInfo.id_token) {
creds.id_token = tokenInfo.id_token
}
if (tokenInfo.email) {
creds.email = tokenInfo.email
} }
// Include OpenAI specific IDs (required for forwarding)
if (tokenInfo.chatgpt_account_id) { if (tokenInfo.chatgpt_account_id) {
creds.chatgpt_account_id = tokenInfo.chatgpt_account_id creds.chatgpt_account_id = tokenInfo.chatgpt_account_id
} }
...@@ -207,6 +213,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { ...@@ -207,6 +213,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
if (tokenInfo.organization_id) { if (tokenInfo.organization_id) {
creds.organization_id = tokenInfo.organization_id creds.organization_id = tokenInfo.organization_id
} }
if (tokenInfo.plan_type) {
creds.plan_type = tokenInfo.plan_type
}
if (tokenInfo.client_id) {
creds.client_id = tokenInfo.client_id
}
return creds return creds
} }
...@@ -220,6 +232,9 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { ...@@ -220,6 +232,9 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
if (tokenInfo.name) { if (tokenInfo.name) {
extra.name = tokenInfo.name extra.name = tokenInfo.name
} }
if (tokenInfo.privacy_mode) {
extra.privacy_mode = tokenInfo.privacy_mode
}
return Object.keys(extra).length > 0 ? extra : undefined return Object.keys(extra).length > 0 ? extra : undefined
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment