Commit b764d3b8 authored by ius's avatar ius
Browse files

Merge remote-tracking branch 'origin/main' into feat/billing-ledger-decouple-usage-log-20260312

parents 611fd884 826090e0
......@@ -6069,6 +6069,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
intervalCh = intervalTicker.C
}
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
......@@ -6267,6 +6283,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
break
}
flusher.Flush()
lastDataAt = time.Now()
}
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
......@@ -6298,6 +6315,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing")
continue
}
flusher.Flush()
}
}
......
package service
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) {
input := []byte(`{
"contents": [
{
"role": "user",
"parts": [{"text": "hello"}]
},
{
"role": "model",
"parts": [
{"text": "thinking", "thought": true, "thoughtSignature": "sig_1"},
{"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"}
]
}
],
"cachedContent": {
"parts": [{"text": "cached", "thoughtSignature": "sig_3"}]
},
"signature": "keep_me"
}`)
cleaned := CleanGeminiNativeThoughtSignatures(input)
var got map[string]any
require.NoError(t, json.Unmarshal(cleaned, &got))
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`)
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`)
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`)
require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`)
require.Contains(t, string(cleaned), `"signature":"keep_me"`)
}
func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) {
input := []byte(`{"contents":[invalid-json]}`)
cleaned := CleanGeminiNativeThoughtSignatures(input)
require.Equal(t, input, cleaned)
}
func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t *testing.T) {
input := map[string]any{
"thoughtSignature": "sig_root",
"signature": "keep_signature",
"nested": []any{
map[string]any{
"thoughtSignature": "sig_nested",
"signature": "keep_nested_signature",
},
},
}
got, ok := replaceThoughtSignaturesRecursive(input).(map[string]any)
require.True(t, ok)
require.Equal(t, antigravity.DummyThoughtSignature, got["thoughtSignature"])
require.Equal(t, "keep_signature", got["signature"])
nested, ok := got["nested"].([]any)
require.True(t, ok)
nestedMap, ok := nested[0].(map[string]any)
require.True(t, ok)
require.Equal(t, antigravity.DummyThoughtSignature, nestedMap["thoughtSignature"])
require.Equal(t, "keep_nested_signature", nestedMap["signature"])
}
package service
import (
"fmt"
"strings"
)
......@@ -226,6 +227,29 @@ func normalizeCodexModel(model string) string {
return "gpt-5.1"
}
func SupportsVerbosity(model string) bool {
if !strings.HasPrefix(model, "gpt-") {
return true
}
var major, minor int
n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor)
if major > 5 {
return true
}
if major < 5 {
return false
}
// gpt-5
if n == 1 {
return true
}
return minor >= 3
}
func getNormalizedCodexModel(modelID string) string {
if modelID == "" {
return ""
......
This diff is collapsed.
......@@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
// Non-failover error: return Anthropic-formatted error to client
......@@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Record upstream error details for ops logging
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeAnthropicError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
......
......@@ -52,6 +52,8 @@ const (
openAIWSRetryJitterRatioDefault = 0.2
openAICompactSessionSeedKey = "openai_compact_session_seed"
codexCLIVersion = "0.104.0"
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
openAICodexSnapshotPersistMinInterval = 30 * time.Second
)
// OpenAI allowed headers whitelist (for non-passthrough).
......@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct {
nonRetryableFastFallback atomic.Int64
}
type accountWriteThrottle struct {
minInterval time.Duration
mu sync.Mutex
lastByID map[int64]time.Time
}
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
return &accountWriteThrottle{
minInterval: minInterval,
lastByID: make(map[int64]time.Time),
}
}
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
if t == nil || id <= 0 || t.minInterval <= 0 {
return true
}
t.mu.Lock()
defer t.mu.Unlock()
if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval {
return false
}
t.lastByID[id] = now
if len(t.lastByID) > 4096 {
cutoff := now.Add(-4 * t.minInterval)
for accountID, writtenAt := range t.lastByID {
if writtenAt.Before(cutoff) {
delete(t.lastByID, accountID)
}
}
}
return true
}
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
......@@ -290,6 +332,7 @@ type OpenAIGatewayService struct {
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter
codexSnapshotThrottle *accountWriteThrottle
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
......@@ -332,17 +375,25 @@ func NewOpenAIGatewayService(
nil,
"service.openai_gateway",
),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
svc.logOpenAIWSModeBootstrap()
return svc
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle
}
return defaultOpenAICodexSnapshotPersistThrottle
}
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
......@@ -1719,6 +1770,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true
markPatchSet("model", normalizedModel)
}
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
// 确保高版本模型向低版本模型映射不报错
if !SupportsVerbosity(normalizedModel) {
if text, ok := reqBody["text"].(map[string]any); ok {
delete(text, "verbosity")
}
}
}
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
......@@ -2954,6 +3013,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// compatErrorWriter is the signature for format-specific error writers used by
// the compat paths (Chat Completions and Anthropic Messages).
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
// handleCompatErrorResponse is the shared non-failover error handler for the
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
// tracking, secondary failover) but delegates the final error write to the
// format-specific writer function.
func (s *OpenAIGatewayService) handleCompatErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
writeError compatErrorWriter,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes — if the account does not handle this status,
// return a generic error without exposing upstream details.
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
}
// Track rate limits and decide whether to trigger secondary failover.
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
}
kind := "http_error"
if shouldDisable {
kind = "failover"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
Message: upstreamMsg,
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: body,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
// Map status code to error type and write response
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
usage *OpenAIUsage
......@@ -4071,11 +4244,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if len(updates) == 0 && resetAt == nil {
return
}
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
if !shouldPersistUpdates && resetAt == nil {
return
}
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if len(updates) > 0 {
if shouldPersistUpdates {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}
if resetAt != nil {
......
......@@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
}
}
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) {
repo := &openAICodexSnapshotAsyncRepo{
updateExtraCh: make(chan map[string]any, 2),
rateLimitCh: make(chan time.Time, 2),
}
svc := &OpenAIGatewayService{
accountRepo: repo,
codexSnapshotThrottle: newAccountWriteThrottle(time.Hour),
}
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: ptrFloat64WS(94),
PrimaryResetAfterSeconds: ptrIntWS(3600),
PrimaryWindowMinutes: ptrIntWS(10080),
SecondaryUsedPercent: ptrFloat64WS(22),
SecondaryResetAfterSeconds: ptrIntWS(1200),
SecondaryWindowMinutes: ptrIntWS(300),
}
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
select {
case <-repo.updateExtraCh:
case <-time.After(2 * time.Second):
t.Fatal("等待第一次 codex 快照落库超时")
}
select {
case updates := <-repo.updateExtraCh:
t.Fatalf("unexpected second codex snapshot write: %v", updates)
case <-time.After(200 * time.Millisecond):
}
}
func ptrFloat64WS(v float64) *float64 { return &v }
func ptrIntWS(v int) *int { return &v }
......
......@@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.HasError && acc.TempUnschedulableUntil == nil
})), true
case "group_rate_limit_ratio":
if groupID == nil || *groupID <= 0 {
return 0, false
}
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
if availability.Group == nil || availability.Group.TotalAccounts <= 0 {
return 0, true
}
return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true
case "account_error_ratio":
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
total := int64(len(availability.Accounts))
if total <= 0 {
return 0, true
}
errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.HasError && acc.TempUnschedulableUntil == nil
})
return (float64(errorCount) / float64(total)) * 100, true
case "overload_account_count":
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.IsOverloaded
})), true
}
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
......
......@@ -7,6 +7,7 @@ import (
type OpsRepository interface {
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
......
......@@ -7,6 +7,8 @@ import (
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
type opsRepoMock struct {
InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
......@@ -14,9 +16,19 @@ type opsRepoMock struct {
}
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
if m.InsertErrorLogFn != nil {
return m.InsertErrorLogFn(ctx, input)
}
return 0, nil
}
func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
if m.BatchInsertErrorLogsFn != nil {
return m.BatchInsertErrorLogsFn(ctx, inputs)
}
return int64(len(inputs)), nil
}
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil
}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -31,6 +31,7 @@ var (
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
......@@ -695,6 +696,36 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
}
// AdminResetQuota manually resets the daily and/or weekly usage windows.
// Uses startOfDay(now) as the new window start, matching automatic resets.
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) {
if !resetDaily && !resetWeekly {
return nil, ErrInvalidInput
}
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, err
}
windowStart := startOfDay(time.Now())
if resetDaily {
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
return nil, err
}
}
if resetWeekly {
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
return nil, err
}
}
// Invalidate caches, same as CheckAndResetWindows
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
// Return the refreshed subscription from DB
return s.userSubRepo.GetByID(ctx, subscriptionID)
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
// 使用当天零点作为新窗口起始时间
......
This diff is collapsed.
......@@ -176,6 +176,7 @@ const formatScopeName = (scope: string): string => {
'gemini-2.5-flash-lite': 'G25FL',
'gemini-2.5-flash-thinking': 'G25FT',
'gemini-2.5-pro': 'G25P',
'gemini-2.5-flash-image': 'G25I',
// Gemini 3 系列
'gemini-3-flash': 'G3F',
'gemini-3.1-pro-high': 'G3PH',
......
......@@ -521,7 +521,7 @@ const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(
// Gemini Image from API
const antigravity3ImageUsageFromAPI = computed(() =>
getAntigravityUsageFromAPI(['gemini-3.1-flash-image', 'gemini-3-pro-image'])
getAntigravityUsageFromAPI(['gemini-2.5-flash-image', 'gemini-3.1-flash-image', 'gemini-3-pro-image'])
)
// Claude from API (all Claude model variants)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment