Commit 8f0ea7a0 authored by InCerry's avatar InCerry
Browse files

Merge branch 'main' into fix/enc_coot

parents e4a4dfd0 a1dc0089
......@@ -2173,10 +2173,10 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
}
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
// 适用于配置了 quota_limit 的 apikey 类型账号
// isAccountSchedulableForQuota 检查账号是否在配额限制内
// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
if account.Type != AccountTypeAPIKey {
if !account.IsAPIKeyOrBedrock() {
return true
}
return !account.IsQuotaExceeded()
......@@ -3532,9 +3532,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
return apiKey, "apikey", nil
case AccountTypeBedrock:
return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token
case AccountTypeBedrockAPIKey:
return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
......@@ -5186,7 +5184,7 @@ func (s *GatewayService) forwardBedrock(
if account.IsBedrockAPIKey() {
bedrockAPIKey = account.GetCredential("api_key")
if bedrockAPIKey == "" {
return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials")
return nil, fmt.Errorf("api_key not found in bedrock credentials")
}
} else {
signer, err = NewBedrockSignerFromAccount(account)
......@@ -5377,6 +5375,7 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
......@@ -5400,6 +5399,7 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
......@@ -5808,9 +5808,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
return betaPolicyResult{}
}
isOAuth := account.IsOAuth()
isBedrock := account.IsBedrock()
var result betaPolicyResult
for _, rule := range settings.Rules {
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
switch rule.Action {
......@@ -5870,14 +5871,16 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
}
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
switch scope {
case BetaPolicyScopeAll:
return true
case BetaPolicyScopeOAuth:
return isOAuth
case BetaPolicyScopeAPIKey:
return !isOAuth
return !isOAuth && !isBedrock
case BetaPolicyScopeBedrock:
return isBedrock
default:
return true // unknown scope → match all (fail-open)
}
......@@ -5959,12 +5962,13 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
return nil
}
isOAuth := account.IsOAuth()
isBedrock := account.IsBedrock()
tokenSet := buildBetaTokenSet(tokens)
for _, rule := range settings.Rules {
if rule.Action != BetaPolicyActionBlock {
continue
}
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
if _, present := tokenSet[rule.BetaToken]; present {
......@@ -7199,7 +7203,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
......@@ -7287,7 +7291,7 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
}
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
}
......
......@@ -339,8 +339,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
typ, _ := m["type"].(string)
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
fixIDPrefix := func(id string) string {
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
fixCallIDPrefix := func(id string) string {
if id == "" || strings.HasPrefix(id, "fc") {
return id
}
......@@ -358,8 +359,8 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
for key, value := range m {
newItem[key] = value
}
if id, ok := newItem["id"].(string); ok && id != "" {
newItem["id"] = fixIDPrefix(id)
if id, ok := newItem["id"].(string); ok && strings.HasPrefix(id, "call_") {
newItem["id"] = fixCallIDPrefix(id)
}
filtered = append(filtered, newItem)
continue
......@@ -390,7 +391,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
if callID != "" {
fixedCallID := fixIDPrefix(callID)
fixedCallID := fixCallIDPrefix(callID)
if fixedCallID != callID {
ensureCopy()
newItem["call_id"] = fixedCallID
......@@ -404,14 +405,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
if !isCodexToolCallItemType(typ) {
delete(newItem, "call_id")
}
} else {
if id, ok := newItem["id"].(string); ok && id != "" {
fixedID := fixIDPrefix(id)
if fixedID != id {
ensureCopy()
newItem["id"] = fixedID
}
}
}
filtered = append(filtered, newItem)
......
......@@ -33,12 +33,63 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "item_reference", first["type"])
require.Equal(t, "fc_ref1", first["id"])
require.Equal(t, "ref1", first["id"])
// 校验 input[1] 为 map,确保后续字段断言安全。
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "fc_o1", second["id"])
require.Equal(t, "o1", second["id"])
require.Equal(t, "fc1", second["call_id"])
}
func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
map[string]any{"type": "item_reference", "id": "rs_123"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "msg_0", first["id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "rs_123", second["id"])
}
func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "item_reference", "id": "call_1"},
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "fc1", first["id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "fc1", second["call_id"])
}
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
......
......@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
// 3. Model mapping
mappedModel := account.GetMappedModel(originalModel)
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
responsesReq.Model = mappedModel
logger.L().Debug("openai chat_completions: model mapping applied",
......
......@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
// 3. Model mapping
mappedModel := account.GetMappedModel(originalModel)
// 分组级降级:账号未映射时使用分组默认映射模型
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
responsesReq.Model = mappedModel
logger.L().Debug("openai messages: model mapping applied",
......
package service
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil {
if defaultMappedModel != "" {
return defaultMappedModel
}
return requestedModel
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
if !matched && defaultMappedModel != "" {
return defaultMappedModel
}
return mappedModel
}
package service
import "testing"
func TestResolveOpenAIForwardModel(t *testing.T) {
tests := []struct {
name string
account *Account
requestedModel string
defaultMappedModel string
expectedModel string
}{
{
name: "falls back to group default when account has no mapping",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves wildcard passthrough mapping instead of group default",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "uses account remap when explicit target differs",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.4",
},
},
},
requestedModel: "gpt-5",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel {
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel)
}
})
}
}
......@@ -371,6 +371,8 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
DisplayOpenAITokenStats: false,
DisplayAlertEvents: true,
AutoRefreshEnabled: false,
AutoRefreshIntervalSec: 30,
}
......@@ -438,7 +440,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe
return nil, err
}
cfg := &OpsAdvancedSettings{}
cfg := defaultOpsAdvancedSettings()
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
return defaultCfg, nil
}
......
package service
import (
"context"
"encoding/json"
"testing"
)
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
}
if cfg.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = true, want false by default")
}
if !cfg.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = false, want true by default")
}
if repo.setCalls != 1 {
t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls)
}
}
func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
cfg := defaultOpsAdvancedSettings()
cfg.DisplayOpenAITokenStats = true
cfg.DisplayAlertEvents = false
updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg)
if err != nil {
t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err)
}
if !updated.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = false, want true")
}
if updated.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = true, want false")
}
reloaded, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err)
}
if !reloaded.DisplayOpenAITokenStats {
t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true")
}
if reloaded.DisplayAlertEvents {
t.Fatalf("reloaded DisplayAlertEvents = true, want false")
}
}
func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
legacyCfg := map[string]any{
"data_retention": map[string]any{
"cleanup_enabled": false,
"cleanup_schedule": "0 2 * * *",
"error_log_retention_days": 30,
"minute_metrics_retention_days": 30,
"hourly_metrics_retention_days": 30,
},
"aggregation": map[string]any{
"aggregation_enabled": false,
},
"ignore_count_tokens_errors": true,
"ignore_context_canceled": true,
"ignore_no_available_accounts": false,
"ignore_invalid_api_key_errors": false,
"auto_refresh_enabled": false,
"auto_refresh_interval_seconds": 30,
}
raw, err := json.Marshal(legacyCfg)
if err != nil {
t.Fatalf("marshal legacy config: %v", err)
}
repo.values[SettingKeyOpsAdvancedSettings] = string(raw)
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
}
if cfg.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill")
}
if !cfg.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
}
}
......@@ -98,6 +98,8 @@ type OpsAdvancedSettings struct {
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
DisplayAlertEvents bool `json:"display_alert_events"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
}
......
......@@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
// 其他 400 错误(如参数问题)不处理,不禁用账号
case 401:
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if account.Type == AccountTypeOAuth {
// OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。
// Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
// 1. 失效缓存
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
......@@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
shouldDisable = true
} else {
// 非 OAuth 账号(APIKey):保持原有 SetError 行为
// 非 OAuth / Antigravity OAuth:保持 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
......@@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 403:
// 禁止访问:停止调度,记录错误
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
logger.LegacyPrintf(
"service.ratelimit",
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
......@@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
upstreamMsg,
truncateForLog(responseBody, 1024),
)
s.handleAuthError(ctx, account, msg)
shouldDisable = true
shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody)
case 429:
s.handle429(ctx, account, headers, responseBody)
shouldDisable = false
......@@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
// handle403 处理 403 Forbidden 错误
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
// 其他平台保持原有 SetError 行为。
func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
if account.Platform == PlatformAntigravity {
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
}
// 非 Antigravity 平台:保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
}
// handleAntigravity403 处理 Antigravity 平台的 403 错误
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
// violation(违规封号)→ 永久 SetError(需人工处理)
// generic(通用禁止)→ 永久 SetError
func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
fbType := classifyForbiddenType(string(responseBody))
switch fbType {
case forbiddenTypeValidation:
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
msg := "Validation required (403): account needs Google verification"
if upstreamMsg != "" {
msg = "Validation required (403): " + upstreamMsg
}
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
msg += " | validation_url: " + validationURL
}
s.handleAuthError(ctx, account, msg)
return true
case forbiddenTypeViolation:
// 违规封号: 永久禁用,需人工处理
msg := "Account violation (403): terms of service violation"
if upstreamMsg != "" {
msg = "Account violation (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
default:
// 通用 403: 保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
}
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
......@@ -1213,7 +1264,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
}
// 401 首次命中可临时不可调度(给 token 刷新窗口);
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。
if statusCode == http.StatusUnauthorized {
// Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。
if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity {
reason := account.TempUnschedulableReason
// 缓存可能没有 reason,从 DB 回退读取
if reason == "" {
......
......@@ -27,7 +27,40 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
// but DB account has a previous 401 record.
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
t.Run("gemini_escalates", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformGemini,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
})
t.Run("antigravity_stays_temp", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
......@@ -40,7 +73,7 @@ func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "", // cache miss — reason is empty
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
......@@ -54,7 +87,8 @@ func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
})
}
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
......
......@@ -42,23 +42,14 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
}
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Run("gemini", func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: tt.platform,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
......@@ -80,7 +71,27 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
// HandleUpstreamError 中走 SetError 路径。
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
}
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
......
......@@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
const minVersionDBTimeout = 5 * time.Second
// cachedBackendMode Backend Mode cache (in-process, 60s TTL)
type cachedBackendMode struct {
value bool
expiresAt int64 // unix nano
}
var backendModeCache atomic.Value // *cachedBackendMode
var backendModeSF singleflight.Group
const backendModeCacheTTL = 60 * time.Second
const backendModeErrorTTL = 5 * time.Second
const backendModeDBTimeout = 5 * time.Second
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error)
......@@ -128,6 +141,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems,
SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -172,6 +186,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}, nil
}
......@@ -223,6 +238,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
......@@ -247,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: s.version,
}, nil
}
......@@ -461,6 +478,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 分组隔离
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
// Backend Mode
updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled)
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
......@@ -469,6 +489,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
value: settings.MinClaudeCodeVersion,
expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
})
backendModeSF.Forget("backend_mode")
backendModeCache.Store(&cachedBackendMode{
value: settings.BackendModeEnabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
}
......@@ -525,6 +550,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
return value == "true"
}
// IsBackendModeEnabled checks if backend mode is enabled
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path
func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value
}
}
result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) {
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout)
defer cancel()
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
// Setting not yet created (fresh install) - default to disabled with full TTL
backendModeCache.Store(&cachedBackendMode{
value: false,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
return false, nil
}
slog.Warn("failed to get backend_mode_enabled setting", "error", err)
backendModeCache.Store(&cachedBackendMode{
value: false,
expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(),
})
return false, nil
}
enabled := value == "true"
backendModeCache.Store(&cachedBackendMode{
value: enabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
return enabled, nil
})
if val, ok := result.(bool); ok {
return val
}
return false
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
......@@ -719,6 +790,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}
// 解析整数类型
......@@ -1278,7 +1350,7 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
}
validScopes := map[string]bool{
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
}
for i, rule := range settings.Rules {
......
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type bmRepoStub struct {
getValueFn func(ctx context.Context, key string) (string, error)
calls int
}
func (s *bmRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) {
s.calls++
if s.getValueFn == nil {
panic("unexpected GetValue call")
}
return s.getValueFn(ctx, key)
}
func (s *bmRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *bmRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
type bmUpdateRepoStub struct {
updates map[string]string
getValueFn func(ctx context.Context, key string) (string, error)
}
func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if s.getValueFn == nil {
panic("unexpected GetValue call")
}
return s.getValueFn(ctx, key)
}
func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.updates = make(map[string]string, len(settings))
for k, v := range settings {
s.updates[k] = v
}
return nil
}
func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *bmUpdateRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func resetBackendModeTestCache(t *testing.T) {
t.Helper()
backendModeCache.Store((*cachedBackendMode)(nil))
t.Cleanup(func() {
backendModeCache.Store((*cachedBackendMode)(nil))
})
}
func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "false", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "", ErrSettingNotFound
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "", errors.New("db down")
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_CachesResult(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) {
resetBackendModeTestCache(t)
backendModeCache.Store(&cachedBackendMode{
value: true,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
repo := &bmUpdateRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
BackendModeEnabled: false,
})
require.NoError(t, err)
require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled])
require.False(t, svc.IsBackendModeEnabled(context.Background()))
}
......@@ -69,6 +69,9 @@ type SystemSettings struct {
// 分组隔离:允许未分组 Key 调度(默认 false → 403)
AllowUngroupedKeyScheduling bool
// Backend 模式:禁用用户注册和自助服务,仅管理员可登录
BackendModeEnabled bool
}
type DefaultSubscriptionSetting struct {
......@@ -101,6 +104,7 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
Version string
}
......@@ -201,13 +205,14 @@ const (
BetaPolicyScopeAll = "all" // 所有账号类型
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
BetaPolicyScopeBedrock = "bedrock" // 仅 AWS Bedrock 账号
)
// BetaPolicyRule 单条 Beta 策略规则
type BetaPolicyRule struct {
BetaToken string `json:"beta_token"` // beta token 值
Action string `json:"action"` // "pass" | "filter" | "block"
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
}
......
......@@ -841,6 +841,8 @@ export interface OpsAdvancedSettings {
ignore_context_canceled: boolean
ignore_no_available_accounts: boolean
ignore_invalid_api_key_errors: boolean
display_openai_token_stats: boolean
display_alert_events: boolean
auto_refresh_enabled: boolean
auto_refresh_interval_seconds: number
}
......
......@@ -40,6 +40,7 @@ export interface SystemSettings {
purchase_subscription_enabled: boolean
purchase_subscription_url: string
sora_client_enabled: boolean
backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[]
// SMTP settings
smtp_host: string
......@@ -106,6 +107,7 @@ export interface UpdateSettingsRequest {
purchase_subscription_enabled?: boolean
purchase_subscription_url?: string
sora_client_enabled?: boolean
backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[]
smtp_host?: string
smtp_port?: number
......@@ -316,7 +318,7 @@ export async function updateRectifierSettings(
export interface BetaPolicyRule {
beta_token: string
action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey'
scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
error_message?: string
}
......
......@@ -292,17 +292,19 @@ const rpmTooltip = computed(() => {
}
})
// 是否显示各维度配额(仅 apikey 类型)
// 是否显示各维度配额(apikey / bedrock 类型)
const isQuotaEligible = computed(() => props.account.type === 'apikey' || props.account.type === 'bedrock')
const showDailyQuota = computed(() => {
return props.account.type === 'apikey' && (props.account.quota_daily_limit ?? 0) > 0
return isQuotaEligible.value && (props.account.quota_daily_limit ?? 0) > 0
})
const showWeeklyQuota = computed(() => {
return props.account.type === 'apikey' && (props.account.quota_weekly_limit ?? 0) > 0
return isQuotaEligible.value && (props.account.quota_weekly_limit ?? 0) > 0
})
const showTotalQuota = computed(() => {
return props.account.type === 'apikey' && (props.account.quota_limit ?? 0) > 0
return isQuotaEligible.value && (props.account.quota_limit ?? 0) > 0
})
// 格式化费用显示
......
......@@ -36,6 +36,10 @@
<!-- Usage data -->
<div v-else-if="usageInfo" class="space-y-1">
<!-- API error (degraded response) -->
<div v-if="usageInfo.error" class="text-xs text-amber-600 dark:text-amber-400 truncate max-w-[200px]" :title="usageInfo.error">
{{ usageInfo.error }}
</div>
<!-- 5h Window -->
<UsageProgressBar
v-if="usageInfo.five_hour"
......@@ -189,8 +193,53 @@
</span>
</div>
<!-- Forbidden state (403) -->
<div v-if="isForbidden" class="space-y-1">
<span
:class="[
'inline-block rounded px-1.5 py-0.5 text-[10px] font-medium',
forbiddenBadgeClass
]"
>
{{ forbiddenLabel }}
</span>
<div v-if="validationURL" class="flex items-center gap-1">
<a
:href="validationURL"
target="_blank"
rel="noopener noreferrer"
class="text-[10px] text-blue-600 hover:text-blue-800 hover:underline dark:text-blue-400 dark:hover:text-blue-300"
:title="t('admin.accounts.openVerification')"
>
{{ t('admin.accounts.openVerification') }}
</a>
<button
type="button"
class="text-[10px] text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200"
:title="t('admin.accounts.copyLink')"
@click="copyValidationURL"
>
{{ linkCopied ? t('admin.accounts.linkCopied') : t('admin.accounts.copyLink') }}
</button>
</div>
</div>
<!-- Needs reauth (401) -->
<div v-else-if="needsReauth" class="space-y-1">
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-orange-100 text-orange-700 dark:bg-orange-900/40 dark:text-orange-300">
{{ t('admin.accounts.needsReauth') }}
</span>
</div>
<!-- Degraded error (non-403, non-401) -->
<div v-else-if="usageInfo?.error" class="space-y-1">
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-amber-100 text-amber-700 dark:bg-amber-900/40 dark:text-amber-300">
{{ usageErrorLabel }}
</span>
</div>
<!-- Loading state -->
<div v-if="loading" class="space-y-1.5">
<div v-else-if="loading" class="space-y-1.5">
<div class="flex items-center gap-1">
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
<div class="h-1.5 w-8 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
......@@ -816,6 +865,51 @@ const hasIneligibleTiers = computed(() => {
return Array.isArray(ineligibleTiers) && ineligibleTiers.length > 0
})
// Antigravity 403 forbidden 状态
const isForbidden = computed(() => !!usageInfo.value?.is_forbidden)
const forbiddenType = computed(() => usageInfo.value?.forbidden_type || 'forbidden')
const validationURL = computed(() => usageInfo.value?.validation_url || '')
// 需要重新授权(401)
const needsReauth = computed(() => !!usageInfo.value?.needs_reauth)
// 降级错误标签(rate_limited / network_error)
const usageErrorLabel = computed(() => {
const code = usageInfo.value?.error_code
if (code === 'rate_limited') return t('admin.accounts.rateLimited')
return t('admin.accounts.usageError')
})
const forbiddenLabel = computed(() => {
switch (forbiddenType.value) {
case 'validation':
return t('admin.accounts.forbiddenValidation')
case 'violation':
return t('admin.accounts.forbiddenViolation')
default:
return t('admin.accounts.forbidden')
}
})
const forbiddenBadgeClass = computed(() => {
if (forbiddenType.value === 'validation') {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/40 dark:text-yellow-300'
}
return 'bg-red-100 text-red-700 dark:bg-red-900/40 dark:text-red-300'
})
const linkCopied = ref(false)
const copyValidationURL = async () => {
if (!validationURL.value) return
try {
await navigator.clipboard.writeText(validationURL.value)
linkCopied.value = true
setTimeout(() => { linkCopied.value = false }, 2000)
} catch {
// fallback: ignore
}
}
const loadUsage = async () => {
if (!shouldFetchUsage.value) return
......@@ -848,18 +942,30 @@ const makeQuotaBar = (
let resetsAt: string | null = null
if (startKey) {
const extra = props.account.extra as Record<string, unknown> | undefined
const isDaily = startKey.includes('daily')
const mode = isDaily
? (extra?.quota_daily_reset_mode as string) || 'rolling'
: (extra?.quota_weekly_reset_mode as string) || 'rolling'
if (mode === 'fixed') {
// Use pre-computed next reset time for fixed mode
const resetAtKey = isDaily ? 'quota_daily_reset_at' : 'quota_weekly_reset_at'
resetsAt = (extra?.[resetAtKey] as string) || null
} else {
// Rolling mode: compute from start + period
const startStr = extra?.[startKey] as string | undefined
if (startStr) {
const startDate = new Date(startStr)
const periodMs = startKey.includes('daily') ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
const periodMs = isDaily ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
}
}
}
return { utilization, resetsAt }
}
const hasApiKeyQuota = computed(() => {
if (props.account.type !== 'apikey') return false
if (props.account.type !== 'apikey' && props.account.type !== 'bedrock') return false
return (
(props.account.quota_daily_limit ?? 0) > 0 ||
(props.account.quota_weekly_limit ?? 0) > 0 ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment