Commit bb399e56 authored by Wang Lvyuan's avatar Wang Lvyuan
Browse files

merge: resolve upstream main conflicts for bulk OpenAI passthrough

parents 73d72651 0f033930
......@@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re
return nil
}
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
_ = platform
_ = accountType
_ = status
_ = search
_ = groupID
_ = privacyMode
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
}
......@@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Len(t, accounts, 1)
......
......@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page,
PageSize: opsAccountsPageSize,
}, platformFilter, "", "", "", 0)
}, platformFilter, "", "", "", 0, "")
if err != nil {
return nil, err
}
......
......@@ -62,6 +62,12 @@ type OpsErrorLog struct {
ClientIP *string `json:"client_ip"`
RequestPath string `json:"request_path"`
Stream bool `json:"stream"`
InboundEndpoint string `json:"inbound_endpoint"`
UpstreamEndpoint string `json:"upstream_endpoint"`
RequestedModel string `json:"requested_model"`
UpstreamModel string `json:"upstream_model"`
RequestType *int16 `json:"request_type"`
}
type OpsErrorLogDetail struct {
......
......@@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct {
Model string
RequestPath string
Stream bool
// InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint string
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint string
// RequestedModel is the client-requested model name before mapping.
RequestedModel string
// UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping.
UpstreamModel string
// RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2.
// Matches service.RequestType enum semantics from usage_log.go.
RequestType *int16
UserAgent string
ErrorPhase string
......
......@@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct {
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
// UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped).
// Helps debug 404/routing errors by showing which endpoint was targeted.
UpstreamURL string `json:"upstream_url,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt.
UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
......@@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind)
ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL)
ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail)
if ev.Message != "" {
......@@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
}
return out, nil
}
// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment
// to avoid leaking sensitive query parameters (e.g. OAuth tokens).
func safeUpstreamURL(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return ""
}
if idx := strings.IndexByte(rawURL, '?'); idx >= 0 {
rawURL = rawURL[:idx]
}
if idx := strings.IndexByte(rawURL, '#'); idx >= 0 {
rawURL = rawURL[:idx]
}
return rawURL
}
......@@ -8,6 +8,27 @@ import (
"github.com/stretchr/testify/require"
)
func TestSafeUpstreamURL(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"},
{"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"},
{"strips both", "https://host/path?token=secret#x", "https://host/path"},
{"no query or fragment", "https://host/path", "https://host/path"},
{"empty string", "", ""},
{"whitespace only", " ", ""},
{"query before fragment", "https://h/p?a=1#f", "https://h/p"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, safeUpstreamURL(tt.input))
})
}
}
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
......
......@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account.Credentials = make(map[string]any)
}
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil {
if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
......
......@@ -15,9 +15,11 @@ import (
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
setErrorCalls int
tempCalls int
lastErrorMsg string
setErrorCalls int
tempCalls int
updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
......@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return nil
}
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCredentialsCalls++
r.lastCredentials = cloneCredentials(credentials)
return nil
}
type tokenCacheInvalidatorRecorder struct {
accounts []*Account
err error
......@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1)
}
......@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}
......@@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic(
func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]Account, *pagination.PaginationResult, error) {
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) {
......
......@@ -150,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
}
......@@ -195,6 +196,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}, nil
......@@ -247,6 +249,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version,omitempty"`
......@@ -272,6 +275,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: s.version,
......@@ -314,6 +318,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result
}
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
if raw == "" {
return json.RawMessage("[]")
}
if json.Valid([]byte(raw)) {
return json.RawMessage(raw)
}
return json.RawMessage("[]")
}
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
......@@ -454,6 +470,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
......@@ -740,6 +757,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
......@@ -805,6 +823,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}
......
......@@ -43,6 +43,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
DefaultConcurrency int
DefaultBalance float64
......@@ -104,6 +105,7 @@ type PublicSettings struct {
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
......
......@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
}
if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() {
if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
......
......@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
newCredentials, err = refresher.Refresh(ctx, account)
if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr)
}
}
......
......@@ -14,19 +14,40 @@ import (
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
updateCalls int
setErrorCalls int
clearTempCalls int
lastAccount *Account
updateErr error
updateCalls int
fullUpdateCalls int
updateCredentialsCalls int
setErrorCalls int
clearTempCalls int
lastAccount *Account
updateErr error
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.fullUpdateCalls++
r.lastAccount = account
return r.updateErr
}
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCalls++
r.updateCredentialsCalls++
if r.updateErr != nil {
return r.updateErr
}
cloned := cloneCredentials(credentials)
if r.accountsByID != nil {
if acc, ok := r.accountsByID[id]; ok && acc != nil {
acc.Credentials = cloned
r.lastAccount = acc
return nil
}
}
r.lastAccount = &Account{ID: id, Credentials: cloned}
return nil
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
......@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
......@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
}
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
Credentials: map[string]any{
"access_token": "old-token",
},
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.NotNil(t, account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
......@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
}
......
-- Ops error logs: add endpoint, model mapping, and request_type fields
-- to match usage_logs observability coverage.
--
-- All columns are nullable with no default to preserve backward compatibility
-- with existing rows.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256),
ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256);
-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100),
ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);
-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS request_type SMALLINT;
COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.';
COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.';
COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).';
COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. NULL means no mapping applied.';
COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. Matches usage_logs.request_type semantics.';
......@@ -36,6 +36,7 @@ export async function list(
status?: string
group?: string
search?: string
privacy_mode?: string
lite?: string
},
options?: {
......@@ -68,6 +69,7 @@ export async function listWithEtag(
status?: string
group?: string
search?: string
privacy_mode?: string
lite?: string
},
options?: {
......@@ -550,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
export async function refreshOpenAIToken(
refreshToken: string,
proxyId?: number | null,
endpoint: string = '/admin/openai/refresh-token'
endpoint: string = '/admin/openai/refresh-token',
clientId?: string
): Promise<Record<string, unknown>> {
const payload: { refresh_token: string; proxy_id?: number } = {
const payload: { refresh_token: string; proxy_id?: number; client_id?: string } = {
refresh_token: refreshToken
}
if (proxyId) {
payload.proxy_id = proxyId
}
if (clientId) {
payload.client_id = clientId
}
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
return data
}
......
......@@ -969,6 +969,13 @@ export interface OpsErrorLog {
client_ip?: string | null
request_path?: string
stream?: boolean
// Error observability context (endpoint + model mapping)
inbound_endpoint?: string
upstream_endpoint?: string
requested_model?: string
upstream_model?: string
request_type?: number | null
}
export interface OpsErrorDetail extends OpsErrorLog {
......
......@@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
import type { CustomMenuItem } from '@/types'
import type { CustomMenuItem, CustomEndpoint } from '@/types'
export interface DefaultSubscriptionSetting {
group_id: number
......@@ -43,6 +43,7 @@ export interface SystemSettings {
sora_client_enabled: boolean
backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
// SMTP settings
smtp_host: string
smtp_port: number
......@@ -112,6 +113,7 @@ export interface UpdateSettingsRequest {
sora_client_enabled?: boolean
backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[]
custom_endpoints?: CustomEndpoint[]
smtp_host?: string
smtp_port?: number
smtp_username?: string
......
......@@ -661,6 +661,43 @@
</div>
</div>
<!-- OpenAI OAuth WS mode -->
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-openai-ws-mode-label"
class="input-label mb-0"
for="bulk-edit-openai-ws-mode-enabled"
>
{{ t('admin.accounts.openai.wsMode') }}
</label>
<input
v-model="enableOpenAIWSMode"
id="bulk-edit-openai-ws-mode-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-ws-mode"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-ws-mode"
:class="!enableOpenAIWSMode && 'pointer-events-none opacity-50'"
>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.wsModeDesc') }}
</p>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t(openAIWSModeConcurrencyHintKey) }}
</p>
<Select
v-model="openaiOAuthResponsesWebSocketV2Mode"
data-testid="bulk-edit-openai-ws-mode-select"
:options="openAIWSModeOptions"
aria-labelledby="bulk-edit-openai-ws-mode-label"
/>
</div>
</div>
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
......@@ -883,6 +920,13 @@ import {
buildModelMappingObject as buildModelMappingPayload,
getPresetMappingsByPlatform
} from '@/composables/useModelWhitelist'
import {
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
resolveOpenAIWSModeConcurrencyHintKey
} from '@/utils/openaiWsMode'
import type { OpenAIWSMode } from '@/utils/openaiWsMode'
interface Props {
show: boolean
accountIds: number[]
......@@ -913,6 +957,15 @@ const allOpenAIPassthroughCapable = computed(() => {
)
})
const allOpenAIOAuth = computed(() => {
return (
props.selectedPlatforms.length === 1 &&
props.selectedPlatforms[0] === 'openai' &&
props.selectedTypes.length > 0 &&
props.selectedTypes.every(t => t === 'oauth')
)
})
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
const allAnthropicOAuthOrSetupToken = computed(() => {
return (
......@@ -957,6 +1010,7 @@ const enableRateMultiplier = ref(false)
const enableStatus = ref(false)
const enableGroups = ref(false)
const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
const enableRpmLimit = ref(false)
// State - field values
......@@ -979,6 +1033,7 @@ const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref<number | null>(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
......@@ -1005,10 +1060,19 @@ const statusOptions = computed(() => [
{ value: 'active', label: t('common.active') },
{ value: 'inactive', label: t('common.inactive') }
])
const isOpenAIModelRestrictionDisabled = computed(() =>
allOpenAIPassthroughCapable.value &&
enableOpenAIPassthrough.value &&
openaiPassthroughEnabled.value
const isOpenAIModelRestrictionDisabled = computed(
() =>
allOpenAIPassthroughCapable.value &&
enableOpenAIPassthrough.value &&
openaiPassthroughEnabled.value
)
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
)
// Model mapping helpers
......@@ -1180,6 +1244,14 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
updates.credentials = credentials
}
if (enableOpenAIWSMode.value) {
const extra = ensureExtra()
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
openaiOAuthResponsesWebSocketV2Mode.value
)
}
// RPM limit settings (写入 extra 字段)
if (enableRpmLimit.value) {
const extra = ensureExtra()
......@@ -1269,6 +1341,7 @@ const handleSubmit = async () => {
enableRateMultiplier.value ||
enableStatus.value ||
enableGroups.value ||
enableOpenAIWSMode.value ||
enableRpmLimit.value ||
userMsgQueueMode.value !== null
......@@ -1361,6 +1434,7 @@ watch(
enableStatus.value = false
enableGroups.value = false
enableOpenAIPassthrough.value = false
enableOpenAIWSMode.value = false
enableRpmLimit.value = false
// Reset all values
......@@ -1379,6 +1453,7 @@ watch(
rateMultiplier.value = 1
status.value = 'active'
groupIds.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
rpmLimitEnabled.value = false
bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered'
......
......@@ -2504,6 +2504,7 @@
:allow-multiple="form.platform === 'anthropic'"
:show-cookie-option="form.platform === 'anthropic'"
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
:show-mobile-refresh-token-option="form.platform === 'openai'"
:show-session-token-option="form.platform === 'sora'"
:show-access-token-option="form.platform === 'sora'"
:platform="form.platform"
......@@ -2511,6 +2512,7 @@
@generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth"
@validate-refresh-token="handleValidateRefreshToken"
@validate-mobile-refresh-token="handleOpenAIValidateMobileRT"
@validate-session-token="handleValidateSessionToken"
@import-access-token="handleImportAccessToken"
/>
......@@ -4360,11 +4362,14 @@ const handleOpenAIExchange = async (authCode: string) => {
}
// OpenAI 手动 RT 批量验证和创建
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
// OpenAI Mobile RT 使用的 client_id(与后端 openai.SoraClientID 一致)
const OPENAI_MOBILE_RT_CLIENT_ID = 'app_LlGpXReQgckcGGUo2JrYvtJK'
// OpenAI/Sora RT 批量验证和创建(共享逻辑)
const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string) => {
const oauthClient = activeOpenAIOAuth.value
if (!refreshTokenInput.trim()) return
// Parse multiple refresh tokens (one per line)
const refreshTokens = refreshTokenInput
.split('\n')
.map((rt) => rt.trim())
......@@ -4389,7 +4394,8 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
try {
const tokenInfo = await oauthClient.validateRefreshToken(
refreshTokens[i],
form.proxy_id
form.proxy_id,
clientId
)
if (!tokenInfo) {
failedCount++
......@@ -4399,6 +4405,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
}
const credentials = oauthClient.buildCredentials(tokenInfo)
if (clientId) {
credentials.client_id = clientId
}
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
const extra = buildOpenAIExtra(oauthExtra)
......@@ -4410,8 +4419,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
}
}
// Generate account name with index for batch
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
// Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'
const accountName = refreshTokens.length > 1 ? `${baseName} #${i + 1}` : baseName
let openaiAccountId: string | number | undefined
......@@ -4494,6 +4504,12 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
}
}
// 手动输入 RT(Codex CLI client_id,默认)
const handleOpenAIValidateRT = (rt: string) => handleOpenAIBatchRT(rt)
// 手动输入 Mobile RT(SoraClientID)
const handleOpenAIValidateMobileRT = (rt: string) => handleOpenAIBatchRT(rt, OPENAI_MOBILE_RT_CLIENT_ID)
// Sora 手动 ST 批量验证和创建
const handleSoraValidateST = async (sessionTokenInput: string) => {
const oauthClient = activeOpenAIOAuth.value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment