"...specs/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "18b8bd43ad8c5fe00d19a67da101e2fe83814620"
Commit 31fe0178 authored by yangjianbo's avatar yangjianbo
Browse files
parents d9e345f2 ba5a0d47
...@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct { ...@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
UpdatedAt string `json:"updated_at,omitempty"` UpdatedAt string `json:"updated_at,omitempty"`
} }
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
type NormalizedCodexLimits struct {
Used5hPercent *float64
Reset5hSeconds *int
Window5hMinutes *int
Used7dPercent *float64
Reset7dSeconds *int
Window7dMinutes *int
}
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
// Returns nil if snapshot is nil or has no useful data.
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
if s == nil {
return nil
}
result := &NormalizedCodexLimits{}
primaryMins := 0
secondaryMins := 0
hasPrimaryWindow := false
hasSecondaryWindow := false
if s.PrimaryWindowMinutes != nil {
primaryMins = *s.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if s.SecondaryWindowMinutes != nil {
secondaryMins = *s.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine mapping based on window_minutes
use5hFromPrimary := false
use7dFromPrimary := false
if hasPrimaryWindow && hasSecondaryWindow {
// Both known: smaller window is 5h, larger is 7d
if primaryMins < secondaryMins {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasPrimaryWindow {
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
if primaryMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasSecondaryWindow {
// Only secondary known: classify by threshold
if secondaryMins <= 360 {
// 5h from secondary, so primary (if any data) is 7d
use7dFromPrimary = true
} else {
// 7d from secondary, so primary (if any data) is 5h
use5hFromPrimary = true
}
} else {
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
use7dFromPrimary = true
}
// Assign values
if use5hFromPrimary {
result.Used5hPercent = s.PrimaryUsedPercent
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
result.Window5hMinutes = s.PrimaryWindowMinutes
result.Used7dPercent = s.SecondaryUsedPercent
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
result.Window7dMinutes = s.SecondaryWindowMinutes
} else if use7dFromPrimary {
result.Used7dPercent = s.PrimaryUsedPercent
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
result.Window7dMinutes = s.PrimaryWindowMinutes
result.Used5hPercent = s.SecondaryUsedPercent
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
result.Window5hMinutes = s.SecondaryWindowMinutes
}
return result
}
// OpenAIUsage represents OpenAI API response usage // OpenAIUsage represents OpenAI API response usage
type OpenAIUsage struct { type OpenAIUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
...@@ -70,12 +156,15 @@ type OpenAIUsage struct { ...@@ -70,12 +156,15 @@ type OpenAIUsage struct {
// OpenAIForwardResult represents the result of forwarding // OpenAIForwardResult represents the result of forwarding
type OpenAIForwardResult struct { type OpenAIForwardResult struct {
RequestID string RequestID string
Usage OpenAIUsage Usage OpenAIUsage
Model string Model string
Stream bool // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
Duration time.Duration // Stored for usage records display; nil means not provided / not applicable.
FirstTokenMs *int ReasoningEffort *string
Stream bool
Duration time.Duration
FirstTokenMs *int
} }
// OpenAIGatewayService handles OpenAI API gateway operations // OpenAIGatewayService handles OpenAI API gateway operations
...@@ -756,6 +845,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -756,6 +845,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true bodyModified = true
} }
} }
// Remove prompt_cache_retention (not supported by upstream OpenAI API)
if _, has := reqBody["prompt_cache_retention"]; has {
delete(reqBody, "prompt_cache_retention")
bodyModified = true
}
} }
// Re-serialize body only if modified // Re-serialize body only if modified
...@@ -867,18 +962,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -867,18 +962,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract and save Codex usage snapshot from response headers (for OAuth accounts) // Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil { if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
} }
} }
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
Stream: reqStream, ReasoningEffort: reasoningEffort,
Duration: time.Since(startTime), Stream: reqStream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil }, nil
} }
...@@ -1174,15 +1272,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1174,15 +1272,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now() lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) // 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent := false errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason string) {
if errorEventSent { if errorEventSent || clientDisconnected {
return return
} }
errorEventSent = true errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) payload := map[string]any{
flusher.Flush() "type": "error",
"sequence_number": 0,
"error": map[string]any{
"type": "upstream_error",
"message": reason,
"code": reason,
},
}
if b, err := json.Marshal(payload); err == nil {
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
flusher.Flush()
}
} }
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
...@@ -1194,6 +1306,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1194,6 +1306,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
if ev.err != nil { if ev.err != nil {
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large") sendErrorEvent("response_too_large")
...@@ -1217,15 +1340,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1217,15 +1340,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.) // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
data = correctedData
line = "data: " + correctedData line = "data: " + correctedData
} }
// Forward line // 写入客户端(客户端断开后继续 drain 上游)
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if !clientDisconnected {
sendErrorEvent("write_failed") if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
} }
flusher.Flush()
// Record first token time // Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" { if firstTokenMs == nil && data != "" && data != "[DONE]" {
...@@ -1235,11 +1362,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1235,11 +1362,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} else { } else {
// Forward non-data lines as-is // Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if !clientDisconnected {
sendErrorEvent("write_failed") if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
} }
flusher.Flush()
} }
case <-intervalCh: case <-intervalCh:
...@@ -1247,6 +1377,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1247,6 +1377,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if time.Since(lastRead) < streamInterval { if time.Since(lastRead) < streamInterval {
continue continue
} }
if clientDisconnected {
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态 // 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil { if s.rateLimitService != nil {
...@@ -1256,11 +1390,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1256,11 +1390,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh: case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
if _, err := fmt.Fprint(w, ":\n\n"); err != nil { if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
continue
} }
flusher.Flush() flusher.Flush()
} }
...@@ -1601,6 +1740,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -1601,6 +1740,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
ReasoningEffort: result.ReasoningEffort,
InputTokens: actualInputTokens, InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
...@@ -1665,8 +1805,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -1665,8 +1805,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return nil return nil
} }
// extractCodexUsageHeaders extracts Codex usage limits from response headers // ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot { // Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
snapshot := &OpenAICodexUsageSnapshot{} snapshot := &OpenAICodexUsageSnapshot{}
hasData := false hasData := false
...@@ -1740,6 +1881,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc ...@@ -1740,6 +1881,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
// Convert snapshot to map for merging into Extra // Convert snapshot to map for merging into Extra
updates := make(map[string]any) updates := make(map[string]any)
// Save raw primary/secondary fields for debugging/tracing
if snapshot.PrimaryUsedPercent != nil { if snapshot.PrimaryUsedPercent != nil {
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
} }
...@@ -1763,116 +1906,115 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc ...@@ -1763,116 +1906,115 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
} }
updates["codex_usage_updated_at"] = snapshot.UpdatedAt updates["codex_usage_updated_at"] = snapshot.UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes // Normalize to canonical 5h/7d fields
// This fixes the issue where OpenAI's primary/secondary naming is reversed if normalized := snapshot.Normalize(); normalized != nil {
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d if normalized.Used5hPercent != nil {
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
}
if normalized.Reset5hSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
}
if normalized.Window5hMinutes != nil {
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
}
if normalized.Used7dPercent != nil {
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
}
if normalized.Reset7dSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
}
if normalized.Window7dMinutes != nil {
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
}
}
// Update account's Extra field asynchronously
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}
// IMPORTANT: We can only reliably determine window type from window_minutes field func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison if reqBody == nil {
return "", false
}
var primaryWindowMins, secondaryWindowMins int // Primary: reasoning.effort
var hasPrimaryWindow, hasSecondaryWindow bool if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
if effort, ok := reasoning["effort"].(string); ok {
return normalizeOpenAIReasoningEffort(effort), true
}
}
// Only use window_minutes for reliable window size comparison // Fallback: some clients may use a flat field.
if snapshot.PrimaryWindowMinutes != nil { if effort, ok := reqBody["reasoning_effort"].(string); ok {
primaryWindowMins = *snapshot.PrimaryWindowMinutes return normalizeOpenAIReasoningEffort(effort), true
hasPrimaryWindow = true
} }
if snapshot.SecondaryWindowMinutes != nil { return "", false
secondaryWindowMins = *snapshot.SecondaryWindowMinutes }
hasSecondaryWindow = true
func deriveOpenAIReasoningEffortFromModel(model string) string {
if strings.TrimSpace(model) == "" {
return ""
} }
// Determine which is 5h and which is 7d modelID := strings.TrimSpace(model)
var use5hFromPrimary, use7dFromPrimary bool if strings.Contains(modelID, "/") {
var use5hFromSecondary, use7dFromSecondary bool parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
if hasPrimaryWindow && hasSecondaryWindow { parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
// Both window sizes known: compare and assign smaller to 5h, larger to 7d switch r {
if primaryWindowMins < secondaryWindowMins { case '-', '_', ' ':
use5hFromPrimary = true return true
use7dFromSecondary = true default:
} else { return false
use5hFromSecondary = true
use7dFromPrimary = true
}
} else if hasPrimaryWindow {
// Only primary window size known: classify by absolute threshold
if primaryWindowMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasSecondaryWindow {
// Only secondary window size known: classify by absolute threshold
if secondaryWindowMins <= 360 {
use5hFromSecondary = true
} else {
use7dFromSecondary = true
}
} else {
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
use5hFromSecondary = true
}
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
use7dFromPrimary = true
} }
})
if len(parts) == 0 {
return ""
} }
// Write canonical 5h fields return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
if use5hFromPrimary { }
if snapshot.PrimaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
} if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if snapshot.PrimaryResetAfterSeconds != nil { if value == "" {
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds return nil
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use5hFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
} }
return &value
} }
// Write canonical 7d fields value := deriveOpenAIReasoningEffortFromModel(requestedModel)
if use7dFromPrimary { if value == "" {
if snapshot.PrimaryUsedPercent != nil { return nil
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use7dFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
}
} }
return &value
}
// Update account's Extra field asynchronously func normalizeOpenAIReasoningEffort(raw string) string {
go func() { value := strings.ToLower(strings.TrimSpace(raw))
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) if value == "" {
defer cancel() return ""
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) }
}()
// Normalize separators for "x-high"/"x_high" variants.
value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value)
switch value {
case "none", "minimal":
return ""
case "low", "medium", "high":
return value
case "xhigh", "extrahigh":
return "xhigh"
default:
// Only store known effort levels for now to keep UI consistent.
return ""
}
} }
...@@ -59,6 +59,25 @@ type stubConcurrencyCache struct { ...@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
skipDefaultLoad bool skipDefaultLoad bool
} }
type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *failingGinWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil { if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok { if result, ok := c.acquireResults[accountID]; ok {
...@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) { ...@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err) t.Fatalf("expected stream timeout error, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "stream_timeout") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: cancelReadCloser{},
Header: http.Header{},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if result == nil || result.usage == nil {
t.Fatalf("expected usage result")
}
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
t.Fatalf("unexpected usage: %+v", *result.usage)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
} }
} }
...@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) { ...@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if !errors.Is(err, bufio.ErrTooLong) { if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err) t.Fatalf("expected ErrTooLong, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "response_too_large") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
} }
} }
......
...@@ -2,9 +2,10 @@ package service ...@@ -2,9 +2,10 @@ package service
import ( import (
"context" "context"
"fmt" "net/http"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
...@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate PKCE values // Generate PKCE values
state, err := openai.GenerateState() state, err := openai.GenerateState()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err) return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
} }
codeVerifier, err := openai.GenerateCodeVerifier() codeVerifier, err := openai.GenerateCodeVerifier()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err) return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
} }
codeChallenge := openai.GenerateCodeChallenge(codeVerifier) codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
...@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate session ID // Generate session ID
sessionID, err := openai.GenerateSessionID() sessionID, err := openai.GenerateSessionID()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err) return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
} }
// Get proxy URL if specified // Get proxy URL if specified
var proxyURL string var proxyURL string
if proxyID != nil { if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil { if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL() proxyURL = proxy.URL()
} }
} }
...@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch ...@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Get session // Get session
session, ok := s.sessionStore.Get(input.SessionID) session, ok := s.sessionStore.Get(input.SessionID)
if !ok { if !ok {
return nil, fmt.Errorf("session not found or expired") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
} }
// Get proxy URL // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL proxyURL := session.ProxyURL
if input.ProxyID != nil { if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID) proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil { if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL() proxyURL = proxy.URL()
} }
} }
...@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch ...@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Exchange code for token // Exchange code for token
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err) return nil, err
} }
// Parse ID token to get user info // Parse ID token to get user info
...@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri ...@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
// RefreshAccountToken refreshes token for an OpenAI account // RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() { if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
} }
refreshToken := account.GetOpenAIRefreshToken() refreshToken := account.GetOpenAIRefreshToken()
if refreshToken == "" { if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
} }
var proxyURL string var proxyURL string
......
...@@ -67,6 +67,8 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -67,6 +67,8 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" { if acc.Platform != "" {
if _, ok := platform[acc.Platform]; !ok { if _, ok := platform[acc.Platform]; !ok {
platform[acc.Platform] = &PlatformAvailability{ platform[acc.Platform] = &PlatformAvailability{
...@@ -84,6 +86,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -84,6 +86,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError { if hasError {
p.ErrorCount++ p.ErrorCount++
} }
if len(scopeRateLimits) > 0 {
if p.ScopeRateLimitCount == nil {
p.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
p.ScopeRateLimitCount[scope]++
}
}
} }
for _, grp := range acc.Groups { for _, grp := range acc.Groups {
...@@ -108,6 +118,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -108,6 +118,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError { if hasError {
g.ErrorCount++ g.ErrorCount++
} }
if len(scopeRateLimits) > 0 {
if g.ScopeRateLimitCount == nil {
g.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
g.ScopeRateLimitCount[scope]++
}
}
} }
displayGroupID := int64(0) displayGroupID := int64(0)
...@@ -140,6 +158,9 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -140,6 +158,9 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
item.RateLimitRemainingSec = &remainingSec item.RateLimitRemainingSec = &remainingSec
} }
} }
if len(scopeRateLimits) > 0 {
item.ScopeRateLimits = scopeRateLimits
}
if isOverloaded && acc.OverloadUntil != nil { if isOverloaded && acc.OverloadUntil != nil {
item.OverloadUntil = acc.OverloadUntil item.OverloadUntil = acc.OverloadUntil
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds()) remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
......
...@@ -39,22 +39,24 @@ type AccountConcurrencyInfo struct { ...@@ -39,22 +39,24 @@ type AccountConcurrencyInfo struct {
// PlatformAvailability aggregates account availability by platform. // PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct { type PlatformAvailability struct {
Platform string `json:"platform"` Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"` TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"` AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"` RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"` ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
} }
// GroupAvailability aggregates account availability by group. // GroupAvailability aggregates account availability by group.
type GroupAvailability struct { type GroupAvailability struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"` GroupName string `json:"group_name"`
Platform string `json:"platform"` Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"` TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"` AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"` RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"` ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
} }
// AccountAvailability represents current availability for a single account. // AccountAvailability represents current availability for a single account.
...@@ -72,10 +74,11 @@ type AccountAvailability struct { ...@@ -72,10 +74,11 @@ type AccountAvailability struct {
IsOverloaded bool `json:"is_overloaded"` IsOverloaded bool `json:"is_overloaded"`
HasError bool `json:"has_error"` HasError bool `json:"has_error"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"` RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
OverloadUntil *time.Time `json:"overload_until"` ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"` OverloadUntil *time.Time `json:"overload_until"`
ErrorMessage string `json:"error_message"` OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
} }
...@@ -83,6 +83,7 @@ type OpsAdvancedSettings struct { ...@@ -83,6 +83,7 @@ type OpsAdvancedSettings struct {
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"` IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
IgnoreContextCanceled bool `json:"ignore_context_canceled"` IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"` IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"` AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"` AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
} }
......
...@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A ...@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
// handle429 处理429限流错误 // handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态 // 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
// 解析重置时间戳 // 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
if account.Platform == PlatformOpenAI {
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt)
return
}
}
// 2. 尝试从响应头解析重置时间(Anthropic)
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
if resetTimestamp == "" { if resetTimestamp == "" {
switch account.Platform {
case PlatformOpenAI:
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
case PlatformGemini, PlatformAntigravity:
// 尝试解析 Gemini 格式(用于其他平台)
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
}
// 没有重置时间,使用默认5分钟 // 没有重置时间,使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) { if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
...@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head ...@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
} }
return return
} }
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
} }
...@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re ...@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
return strings.Contains(msg, "sonnet") return strings.Contains(msg, "sonnet")
} }
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
}
normalized := snapshot.Normalize()
if normalized == nil {
return nil
}
now := time.Now()
// 判断哪个限制被触发(used_percent >= 100)
is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100
is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100
// 优先使用被触发限制的重置时间
if is7dExhausted && normalized.Reset7dSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt)
return &resetAt
}
if is5hExhausted && normalized.Reset5hSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt)
return &resetAt
}
// 都未达到100%但收到429,使用较长的重置时间
var maxResetSecs int
if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset7dSeconds
}
if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset5hSeconds
}
if maxResetSecs > 0 {
resetAt := now.Add(time.Duration(maxResetSecs) * time.Second)
slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt)
return &resetAt
}
return nil
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
// {
// "error": {
// "message": "The usage limit has been reached",
// "type": "usage_limit_reached",
// "resets_at": 1769404154,
// "resets_in_seconds": 133107
// }
// }
func parseOpenAIRateLimitResetTime(body []byte) *int64 {
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
errObj, ok := parsed["error"].(map[string]any)
if !ok {
return nil
}
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
errType, _ := errObj["type"].(string)
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
return nil
}
// 优先使用 resets_at(Unix 时间戳)
if resetsAt, ok := errObj["resets_at"].(float64); ok {
ts := int64(resetsAt)
return &ts
}
if resetsAt, ok := errObj["resets_at"].(string); ok {
if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil {
return &ts
}
}
// 如果没有 resets_at,尝试使用 resets_in_seconds
if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok {
ts := time.Now().Unix() + int64(resetsInSeconds)
return &ts
}
if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok {
if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil {
ts := time.Now().Unix() + sec
return &ts
}
}
return nil
}
// handle529 处理529过载错误 // handle529 处理529过载错误
// 根据配置设置过载冷却时间 // 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *Account) { func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
......
package service
import (
"net/http"
"testing"
"time"
)
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 384607 seconds from now
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 5h limit is exhausted (100% used)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "50")
headers.Set("x-codex-primary-reset-after-seconds", "500000")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 3600 seconds from now
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) {
svc := &RateLimitService{}
// Neither limit at 100%, should use the longer reset time
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "80")
headers.Set("x-codex-primary-reset-after-seconds", "100000")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "90")
headers.Set("x-codex-secondary-reset-after-seconds", "5000")
headers.Set("x-codex-secondary-window-minutes", "300")
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should use the max (100000 seconds from 7d window)
expectedDuration := 100000 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) {
svc := &RateLimitService{}
// No codex headers at all
headers := http.Header{}
headers.Set("content-type", "application/json")
resetAt := svc.calculateOpenAI429ResetTime(headers)
if resetAt != nil {
t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
svc := &RateLimitService{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100") // This is 5h
headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller!
headers.Set("x-codex-secondary-used-percent", "50")
headers.Set("x-codex-secondary-reset-after-seconds", "500000")
headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger!
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestNormalizedCodexLimits(t *testing.T) {
// Test the Normalize() method directly
pUsed := 100.0
pReset := 384607
pWindow := 10080
sUsed := 3.0
sReset := 17369
sWindow := 300
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
PrimaryWindowMinutes: &pWindow,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
SecondaryWindowMinutes: &sWindow,
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Primary has larger window (10080 > 300), so primary should be 7d
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 {
t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 {
t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 {
t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
// Test when only primary has data, no window_minutes
pUsed := 80.0
pReset := 50000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
// No window_minutes, no secondary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 {
t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 {
t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds)
}
// Secondary (5h) should be nil
if normalized.Used5hPercent != nil {
t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent)
}
if normalized.Reset5hSeconds != nil {
t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
// Test when only secondary has data, no window_minutes
sUsed := 60.0
sReset := 3000
snapshot := &OpenAICodexUsageSnapshot{
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes, no primary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 {
t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 {
t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds)
}
// Primary (7d) should be nil
if normalized.Used7dPercent != nil {
t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent)
}
}
func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) {
// Test when both have data but no window_minutes
pUsed := 100.0
pReset := 400000
sUsed := 50.0
sReset := 10000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 {
t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 {
t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 {
t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds)
}
}
func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) {
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc := &RateLimitService{}
// Simulate Anthropic 429 headers
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there are no x-codex-* headers
if resetAt != nil {
t.Errorf("expected nil for Anthropic headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) {
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc := &RateLimitService{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369")
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt for user scenario")
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration := resetAt.Sub(before)
actualDays := duration.Hours() / 24.0
// 384607 / 86400 = ~4.45 days
if actualDays < 4.4 || actualDays > 4.5 {
t.Errorf("expected ~4.45 days, got %.2f days", actualDays)
}
t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays)
}
func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) {
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc := &RateLimitService{}
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
// No reset_after_seconds!
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there's no reset time available
if resetAt != nil {
t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt)
}
}
...@@ -49,6 +49,11 @@ type RedeemCodeRepository interface { ...@@ -49,6 +49,11 @@ type RedeemCodeRepository interface {
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
// ListByUserPaginated returns paginated balance/concurrency history for a specific user.
// codeType filter is optional - pass empty string to return all types.
ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error)
// SumPositiveBalanceByUser returns the total recharged amount (sum of positive balance values) for a user.
SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error)
} }
// GenerateCodesRequest 生成兑换码请求 // GenerateCodesRequest 生成兑换码请求
...@@ -126,7 +131,8 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ ...@@ -126,7 +131,8 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return nil, errors.New("count must be greater than 0") return nil, errors.New("count must be greater than 0")
} }
if req.Value <= 0 { // 邀请码类型不需要数值,其他类型需要
if req.Type != RedeemTypeInvitation && req.Value <= 0 {
return nil, errors.New("value must be greater than 0") return nil, errors.New("value must be greater than 0")
} }
...@@ -139,6 +145,12 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ ...@@ -139,6 +145,12 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codeType = RedeemTypeBalance codeType = RedeemTypeBalance
} }
// 邀请码类型的 value 设为 0
value := req.Value
if codeType == RedeemTypeInvitation {
value = 0
}
codes := make([]RedeemCode, 0, req.Count) codes := make([]RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ { for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode() code, err := s.GenerateRandomCode()
...@@ -149,7 +161,7 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ ...@@ -149,7 +161,7 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codes = append(codes, RedeemCode{ codes = append(codes, RedeemCode{
Code: code, Code: code,
Type: codeType, Type: codeType,
Value: req.Value, Value: value,
Status: StatusUnused, Status: StatusUnused,
}) })
} }
......
...@@ -61,6 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -61,6 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyRegistrationEnabled, SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled, SettingKeyEmailVerifyEnabled,
SettingKeyPromoCodeEnabled, SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
SettingKeyInvitationCodeEnabled,
SettingKeyTotpEnabled,
SettingKeyTurnstileEnabled, SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey, SettingKeyTurnstileSiteKey,
SettingKeySiteName, SettingKeySiteName,
...@@ -71,6 +74,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -71,6 +74,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyDocURL, SettingKeyDocURL,
SettingKeyHomeContent, SettingKeyHomeContent,
SettingKeyHideCcsImportButton, SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
} }
...@@ -86,21 +91,30 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -86,21 +91,30 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
} }
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
return &PublicSettings{ return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", PasswordResetEnabled: passwordResetEnabled,
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
SiteLogo: settings[SettingKeySiteLogo], TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
APIBaseURL: settings[SettingKeyAPIBaseURL], SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
ContactInfo: settings[SettingKeyContactInfo], SiteLogo: settings[SettingKeySiteLogo],
DocURL: settings[SettingKeyDocURL], SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
HomeContent: settings[SettingKeyHomeContent], APIBaseURL: settings[SettingKeyAPIBaseURL],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", ContactInfo: settings[SettingKeyContactInfo],
LinuxDoOAuthEnabled: linuxDoEnabled, DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil }, nil
} }
...@@ -125,37 +139,47 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -125,37 +139,47 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// Return a struct that matches the frontend's expected format // Return a struct that matches the frontend's expected format
return &struct { return &struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` InvitationCodeEnabled bool `json:"invitation_code_enabled"`
SiteName string `json:"site_name"` TotpEnabled bool `json:"totp_enabled"`
SiteLogo string `json:"site_logo,omitempty"` TurnstileEnabled bool `json:"turnstile_enabled"`
SiteSubtitle string `json:"site_subtitle,omitempty"` TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
APIBaseURL string `json:"api_base_url,omitempty"` SiteName string `json:"site_name"`
ContactInfo string `json:"contact_info,omitempty"` SiteLogo string `json:"site_logo,omitempty"`
DocURL string `json:"doc_url,omitempty"` SiteSubtitle string `json:"site_subtitle,omitempty"`
HomeContent string `json:"home_content,omitempty"` APIBaseURL string `json:"api_base_url,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"` ContactInfo string `json:"contact_info,omitempty"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` DocURL string `json:"doc_url,omitempty"`
Version string `json:"version,omitempty"` HomeContent string `json:"home_content,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, InvitationCodeEnabled: settings.InvitationCodeEnabled,
SiteName: settings.SiteName, TotpEnabled: settings.TotpEnabled,
SiteLogo: settings.SiteLogo, TurnstileEnabled: settings.TurnstileEnabled,
SiteSubtitle: settings.SiteSubtitle, TurnstileSiteKey: settings.TurnstileSiteKey,
APIBaseURL: settings.APIBaseURL, SiteName: settings.SiteName,
ContactInfo: settings.ContactInfo, SiteLogo: settings.SiteLogo,
DocURL: settings.DocURL, SiteSubtitle: settings.SiteSubtitle,
HomeContent: settings.HomeContent, APIBaseURL: settings.APIBaseURL,
HideCcsImportButton: settings.HideCcsImportButton, ContactInfo: settings.ContactInfo,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, DocURL: settings.DocURL,
Version: s.version, HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil }, nil
} }
...@@ -167,6 +191,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -167,6 +191,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
// 邮件服务设置(只有非空才更新密码) // 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost updates[SettingKeySMTPHost] = settings.SMTPHost
...@@ -203,6 +230,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -203,6 +230,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyDocURL] = settings.DocURL updates[SettingKeyDocURL] = settings.DocURL
updates[SettingKeyHomeContent] = settings.HomeContent updates[SettingKeyHomeContent] = settings.HomeContent
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
...@@ -262,6 +291,44 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool { ...@@ -262,6 +291,44 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
return value != "false" return value != "false"
} }
// IsInvitationCodeEnabled 检查是否启用邀请码注册功能
func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyInvitationCodeEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
// Password reset requires email verification to be enabled
if !s.IsEmailVerifyEnabled(ctx) {
return false
}
value, err := s.settingRepo.GetValue(ctx, SettingKeyPasswordResetEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
func (s *SettingService) IsTotpEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyTotpEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
func (s *SettingService) IsTotpEncryptionKeyConfigured() bool {
return s.cfg.Totp.EncryptionKeyConfigured
}
// GetSiteName 获取网站名称 // GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string { func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
...@@ -309,15 +376,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -309,15 +376,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置 // 初始化默认设置
defaults := map[string]string{ defaults := map[string]string{
SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false", SettingKeyEmailVerifyEnabled: "false",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API", SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "", SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyPurchaseSubscriptionURL: "",
SettingKeySMTPPort: "587", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeySMTPUseTLS: "false", SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults // Model fallback defaults
SettingKeyEnableModelFallback: "false", SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
...@@ -340,10 +409,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -340,10 +409,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体 // parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
result := &SystemSettings{ result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost], SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername], SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom], SMTPFrom: settings[SettingKeySMTPFrom],
...@@ -361,6 +434,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -361,6 +434,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
DocURL: settings[SettingKeyDocURL], DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent], HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
} }
// 解析整数类型 // 解析整数类型
......
package service package service
type SystemSettings struct { type SystemSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
PromoCodeEnabled bool PromoCodeEnabled bool
PasswordResetEnabled bool
InvitationCodeEnabled bool
TotpEnabled bool // TOTP 双因素认证
SMTPHost string SMTPHost string
SMTPPort int SMTPPort int
...@@ -26,14 +29,16 @@ type SystemSettings struct { ...@@ -26,14 +29,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string LinuxDoConnectRedirectURL string
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
APIBaseURL string APIBaseURL string
ContactInfo string ContactInfo string
DocURL string DocURL string
HomeContent string HomeContent string
HideCcsImportButton bool HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
DefaultConcurrency int DefaultConcurrency int
DefaultBalance float64 DefaultBalance float64
...@@ -57,19 +62,26 @@ type SystemSettings struct { ...@@ -57,19 +62,26 @@ type SystemSettings struct {
} }
type PublicSettings struct { type PublicSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
PromoCodeEnabled bool PromoCodeEnabled bool
TurnstileEnabled bool PasswordResetEnabled bool
TurnstileSiteKey string InvitationCodeEnabled bool
SiteName string TotpEnabled bool // TOTP 双因素认证
SiteLogo string TurnstileEnabled bool
SiteSubtitle string TurnstileSiteKey string
APIBaseURL string SiteName string
ContactInfo string SiteLogo string
DocURL string SiteSubtitle string
HomeContent string APIBaseURL string
HideCcsImportButton bool ContactInfo string
DocURL string
HomeContent string
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
Version string Version string
} }
......
package service
import (
"context"
"log"
"sync"
"time"
)
// SubscriptionExpiryService periodically updates expired subscription status.
type SubscriptionExpiryService struct {
userSubRepo UserSubscriptionRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService {
return &SubscriptionExpiryService{
userSubRepo: userSubRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *SubscriptionExpiryService) Start() {
if s == nil || s.userSubRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *SubscriptionExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *SubscriptionExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
updated, err := s.userSubRepo.BatchUpdateExpiredStatus(ctx)
if err != nil {
log.Printf("[SubscriptionExpiry] Update expired subscriptions failed: %v", err)
return
}
if updated > 0 {
log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated)
}
}
...@@ -324,18 +324,31 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti ...@@ -324,18 +324,31 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
days = -MaxValidityDays days = -MaxValidityDays
} }
now := time.Now()
isExpired := !sub.ExpiresAt.After(now)
// 如果订阅已过期,不允许负向调整
if isExpired && days < 0 {
return nil, infraerrors.BadRequest("CANNOT_SHORTEN_EXPIRED", "cannot shorten an expired subscription")
}
// 计算新的过期时间 // 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days) var newExpiresAt time.Time
if isExpired {
// 已过期:从当前时间开始增加天数
newExpiresAt = now.AddDate(0, 0, days)
} else {
// 未过期:从原过期时间增加/减少天数
newExpiresAt = sub.ExpiresAt.AddDate(0, 0, days)
}
if newExpiresAt.After(MaxExpiresAt) { if newExpiresAt.After(MaxExpiresAt) {
newExpiresAt = MaxExpiresAt newExpiresAt = MaxExpiresAt
} }
// 如果是缩短(负数),检查新的过期时间必须大于当前时间 // 检查新的过期时间必须大于当前时间
if days < 0 { if !newExpiresAt.After(now) {
now := time.Now() return nil, ErrAdjustWouldExpire
if !newExpiresAt.After(now) {
return nil, ErrAdjustWouldExpire
}
} }
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil { if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
...@@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID ...@@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
return nil, err return nil, err
} }
normalizeExpiredWindows(subs) normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, nil return subs, nil
} }
...@@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI ...@@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
return nil, nil, err return nil, nil, err
} }
normalizeExpiredWindows(subs) normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, pag, nil return subs, pag, nil
} }
// List 获取所有订阅(分页,支持筛选) // List 获取所有订阅(分页,支持筛选和排序
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) { func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status) subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
normalizeExpiredWindows(subs) normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, pag, nil return subs, pag, nil
} }
...@@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) { ...@@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
} }
} }
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
func normalizeSubscriptionStatus(subs []UserSubscription) {
now := time.Now()
for i := range subs {
sub := &subs[i]
if sub.Status == SubscriptionStatusActive && !sub.ExpiresAt.After(now) {
sub.Status = SubscriptionStatusExpired
}
}
}
// startOfDay 返回给定时间所在日期的零点(保持原时区) // startOfDay 返回给定时间所在日期的零点(保持原时区)
func startOfDay(t time.Time) time.Time { func startOfDay(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
...@@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte ...@@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
return progresses, nil return progresses, nil
} }
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
}
// ValidateSubscription 验证订阅是否有效 // ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error { func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
if sub.Status == SubscriptionStatusExpired { if sub.Status == SubscriptionStatusExpired {
......
...@@ -18,6 +18,7 @@ type TokenRefreshService struct { ...@@ -18,6 +18,7 @@ type TokenRefreshService struct {
refreshers []TokenRefresher refreshers []TokenRefresher
cfg *config.TokenRefreshConfig cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator cacheInvalidator TokenCacheInvalidator
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
stopCh chan struct{} stopCh chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
...@@ -31,12 +32,14 @@ func NewTokenRefreshService( ...@@ -31,12 +32,14 @@ func NewTokenRefreshService(
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService, antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator, cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config, cfg *config.Config,
) *TokenRefreshService { ) *TokenRefreshService {
s := &TokenRefreshService{ s := &TokenRefreshService{
accountRepo: accountRepo, accountRepo: accountRepo,
cfg: &cfg.TokenRefresh, cfg: &cfg.TokenRefresh,
cacheInvalidator: cacheInvalidator, cacheInvalidator: cacheInvalidator,
schedulerCache: schedulerCache,
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
...@@ -198,6 +201,15 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -198,6 +201,15 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID) log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
} }
} }
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
// 这解决了 token 刷新后调度器缓存数据不一致的问题(#445)
if s.schedulerCache != nil {
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[TokenRefresh] Failed to sync scheduler cache for account %d: %v", account.ID, err)
} else {
log.Printf("[TokenRefresh] Scheduler cache synced for account %d", account.ID)
}
}
return nil return nil
} }
...@@ -237,7 +249,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -237,7 +249,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
} }
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误 // isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效,需要用户重新授权 // 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
func isNonRetryableRefreshError(err error) bool { func isNonRetryableRefreshError(err error) bool {
if err == nil { if err == nil {
return false return false
......
...@@ -70,7 +70,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { ...@@ -70,7 +70,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 5, ID: 5,
Platform: PlatformGemini, Platform: PlatformGemini,
...@@ -98,7 +98,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing ...@@ -98,7 +98,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 6, ID: 6,
Platform: PlatformGemini, Platform: PlatformGemini,
...@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) { ...@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
account := &Account{ account := &Account{
ID: 7, ID: 7,
Platform: PlatformGemini, Platform: PlatformGemini,
...@@ -151,7 +151,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) { ...@@ -151,7 +151,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 8, ID: 8,
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
...@@ -179,7 +179,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { ...@@ -179,7 +179,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 9, ID: 9,
Platform: PlatformGemini, Platform: PlatformGemini,
...@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { ...@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 10, ID: 10,
Platform: PlatformOpenAI, // OpenAI OAuth 账户 Platform: PlatformOpenAI, // OpenAI OAuth 账户
...@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { ...@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 11, ID: 11,
Platform: PlatformGemini, Platform: PlatformGemini,
...@@ -264,7 +264,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { ...@@ -264,7 +264,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 12, ID: 12,
Platform: PlatformGemini, Platform: PlatformGemini,
...@@ -291,7 +291,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin ...@@ -291,7 +291,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 13, ID: 13,
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
...@@ -318,7 +318,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te ...@@ -318,7 +318,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
account := &Account{ account := &Account{
ID: 14, ID: 14,
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
......
package service
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log/slog"
"time"
"github.com/pquerna/otp/totp"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrTotpNotEnabled = infraerrors.BadRequest("TOTP_NOT_ENABLED", "totp feature is not enabled")
ErrTotpAlreadyEnabled = infraerrors.BadRequest("TOTP_ALREADY_ENABLED", "totp is already enabled for this account")
ErrTotpNotSetup = infraerrors.BadRequest("TOTP_NOT_SETUP", "totp is not set up for this account")
ErrTotpInvalidCode = infraerrors.BadRequest("TOTP_INVALID_CODE", "invalid totp code")
ErrTotpSetupExpired = infraerrors.BadRequest("TOTP_SETUP_EXPIRED", "totp setup session expired")
ErrTotpTooManyAttempts = infraerrors.TooManyRequests("TOTP_TOO_MANY_ATTEMPTS", "too many verification attempts, please try again later")
ErrVerifyCodeRequired = infraerrors.BadRequest("VERIFY_CODE_REQUIRED", "email verification code is required")
ErrPasswordRequired = infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
)
// TotpCache defines cache operations for TOTP service
type TotpCache interface {
// Setup session methods
GetSetupSession(ctx context.Context, userID int64) (*TotpSetupSession, error)
SetSetupSession(ctx context.Context, userID int64, session *TotpSetupSession, ttl time.Duration) error
DeleteSetupSession(ctx context.Context, userID int64) error
// Login session methods (for 2FA login flow)
GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error)
SetLoginSession(ctx context.Context, tempToken string, session *TotpLoginSession, ttl time.Duration) error
DeleteLoginSession(ctx context.Context, tempToken string) error
// Rate limiting
IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error)
GetVerifyAttempts(ctx context.Context, userID int64) (int, error)
ClearVerifyAttempts(ctx context.Context, userID int64) error
}
// SecretEncryptor defines encryption operations for TOTP secrets
type SecretEncryptor interface {
Encrypt(plaintext string) (string, error)
Decrypt(ciphertext string) (string, error)
}
// TotpSetupSession represents a TOTP setup session
type TotpSetupSession struct {
Secret string // Plain text TOTP secret (not encrypted yet)
SetupToken string // Random token to verify setup request
CreatedAt time.Time
}
// TotpLoginSession represents a pending 2FA login session
type TotpLoginSession struct {
UserID int64
Email string
TokenExpiry time.Time
}
// TotpStatus represents the TOTP status for a user
type TotpStatus struct {
Enabled bool `json:"enabled"`
EnabledAt *time.Time `json:"enabled_at,omitempty"`
FeatureEnabled bool `json:"feature_enabled"`
}
// TotpSetupResponse represents the response for initiating TOTP setup
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"` // seconds until setup expires
}
const (
totpSetupTTL = 5 * time.Minute
totpLoginTTL = 5 * time.Minute
totpAttemptsTTL = 15 * time.Minute
maxTotpAttempts = 5
totpIssuer = "Sub2API"
)
// TotpService handles TOTP operations
type TotpService struct {
userRepo UserRepository
encryptor SecretEncryptor
cache TotpCache
settingService *SettingService
emailService *EmailService
emailQueueService *EmailQueueService
}
// NewTotpService creates a new TOTP service
func NewTotpService(
userRepo UserRepository,
encryptor SecretEncryptor,
cache TotpCache,
settingService *SettingService,
emailService *EmailService,
emailQueueService *EmailQueueService,
) *TotpService {
return &TotpService{
userRepo: userRepo,
encryptor: encryptor,
cache: cache,
settingService: settingService,
emailService: emailService,
emailQueueService: emailQueueService,
}
}
// GetStatus returns the TOTP status for a user
func (s *TotpService) GetStatus(ctx context.Context, userID int64) (*TotpStatus, error) {
featureEnabled := s.settingService.IsTotpEnabled(ctx)
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
return &TotpStatus{
Enabled: user.TotpEnabled,
EnabledAt: user.TotpEnabledAt,
FeatureEnabled: featureEnabled,
}, nil
}
// InitiateSetup starts the TOTP setup process
// If email verification is enabled, emailCode is required; otherwise password is required
func (s *TotpService) InitiateSetup(ctx context.Context, userID int64, emailCode, password string) (*TotpSetupResponse, error) {
// Check if TOTP feature is enabled globally
if !s.settingService.IsTotpEnabled(ctx) {
return nil, ErrTotpNotEnabled
}
// Get user and check if TOTP is already enabled
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if user.TotpEnabled {
return nil, ErrTotpAlreadyEnabled
}
// Verify identity based on email verification setting
if s.settingService.IsEmailVerifyEnabled(ctx) {
// Email verification enabled - verify email code
if emailCode == "" {
return nil, ErrVerifyCodeRequired
}
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
return nil, err
}
} else {
// Email verification disabled - verify password
if password == "" {
return nil, ErrPasswordRequired
}
if !user.CheckPassword(password) {
return nil, ErrPasswordIncorrect
}
}
// Generate a new TOTP key
key, err := totp.Generate(totp.GenerateOpts{
Issuer: totpIssuer,
AccountName: user.Email,
})
if err != nil {
return nil, fmt.Errorf("generate totp key: %w", err)
}
// Generate a random setup token
setupToken, err := generateRandomToken(32)
if err != nil {
return nil, fmt.Errorf("generate setup token: %w", err)
}
// Store the setup session in cache
session := &TotpSetupSession{
Secret: key.Secret(),
SetupToken: setupToken,
CreatedAt: time.Now(),
}
if err := s.cache.SetSetupSession(ctx, userID, session, totpSetupTTL); err != nil {
return nil, fmt.Errorf("store setup session: %w", err)
}
return &TotpSetupResponse{
Secret: key.Secret(),
QRCodeURL: key.URL(),
SetupToken: setupToken,
Countdown: int(totpSetupTTL.Seconds()),
}, nil
}
// CompleteSetup completes the TOTP setup by verifying the code
func (s *TotpService) CompleteSetup(ctx context.Context, userID int64, totpCode, setupToken string) error {
// Check if TOTP feature is enabled globally
if !s.settingService.IsTotpEnabled(ctx) {
return ErrTotpNotEnabled
}
// Get the setup session
session, err := s.cache.GetSetupSession(ctx, userID)
if err != nil {
return ErrTotpSetupExpired
}
if session == nil {
return ErrTotpSetupExpired
}
// Verify the setup token (constant-time comparison)
if subtle.ConstantTimeCompare([]byte(session.SetupToken), []byte(setupToken)) != 1 {
return ErrTotpSetupExpired
}
// Verify the TOTP code
if !totp.Validate(totpCode, session.Secret) {
return ErrTotpInvalidCode
}
setupSecretPrefix := "N/A"
if len(session.Secret) >= 4 {
setupSecretPrefix = session.Secret[:4]
}
slog.Debug("totp_complete_setup_before_encrypt",
"user_id", userID,
"secret_len", len(session.Secret),
"secret_prefix", setupSecretPrefix)
// Encrypt the secret
encryptedSecret, err := s.encryptor.Encrypt(session.Secret)
if err != nil {
return fmt.Errorf("encrypt totp secret: %w", err)
}
slog.Debug("totp_complete_setup_encrypted",
"user_id", userID,
"encrypted_len", len(encryptedSecret))
// Verify encryption by decrypting
decrypted, decErr := s.encryptor.Decrypt(encryptedSecret)
if decErr != nil {
slog.Debug("totp_complete_setup_verify_failed",
"user_id", userID,
"error", decErr)
} else {
decryptedPrefix := "N/A"
if len(decrypted) >= 4 {
decryptedPrefix = decrypted[:4]
}
slog.Debug("totp_complete_setup_verified",
"user_id", userID,
"original_len", len(session.Secret),
"decrypted_len", len(decrypted),
"match", session.Secret == decrypted,
"decrypted_prefix", decryptedPrefix)
}
// Update user with encrypted TOTP secret
if err := s.userRepo.UpdateTotpSecret(ctx, userID, &encryptedSecret); err != nil {
return fmt.Errorf("update totp secret: %w", err)
}
// Enable TOTP for the user
if err := s.userRepo.EnableTotp(ctx, userID); err != nil {
return fmt.Errorf("enable totp: %w", err)
}
// Clean up the setup session
_ = s.cache.DeleteSetupSession(ctx, userID)
return nil
}
// Disable disables TOTP for a user
// If email verification is enabled, emailCode is required; otherwise password is required
func (s *TotpService) Disable(ctx context.Context, userID int64, emailCode, password string) error {
// Get user
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
if !user.TotpEnabled {
return ErrTotpNotSetup
}
// Verify identity based on email verification setting
if s.settingService.IsEmailVerifyEnabled(ctx) {
// Email verification enabled - verify email code
if emailCode == "" {
return ErrVerifyCodeRequired
}
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
return err
}
} else {
// Email verification disabled - verify password
if password == "" {
return ErrPasswordRequired
}
if !user.CheckPassword(password) {
return ErrPasswordIncorrect
}
}
// Disable TOTP
if err := s.userRepo.DisableTotp(ctx, userID); err != nil {
return fmt.Errorf("disable totp: %w", err)
}
return nil
}
// VerifyCode verifies a TOTP code for a user
func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) error {
slog.Debug("totp_verify_code_called",
"user_id", userID,
"code_len", len(code))
// Check rate limiting
attempts, err := s.cache.GetVerifyAttempts(ctx, userID)
if err == nil && attempts >= maxTotpAttempts {
return ErrTotpTooManyAttempts
}
// Get user
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
slog.Debug("totp_verify_get_user_failed",
"user_id", userID,
"error", err)
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
}
if !user.TotpEnabled || user.TotpSecretEncrypted == nil {
slog.Debug("totp_verify_not_setup",
"user_id", userID,
"enabled", user.TotpEnabled,
"has_secret", user.TotpSecretEncrypted != nil)
return ErrTotpNotSetup
}
slog.Debug("totp_verify_encrypted_secret",
"user_id", userID,
"encrypted_len", len(*user.TotpSecretEncrypted))
// Decrypt the secret
secret, err := s.encryptor.Decrypt(*user.TotpSecretEncrypted)
if err != nil {
slog.Debug("totp_verify_decrypt_failed",
"user_id", userID,
"error", err)
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
}
secretPrefix := "N/A"
if len(secret) >= 4 {
secretPrefix = secret[:4]
}
slog.Debug("totp_verify_decrypted",
"user_id", userID,
"secret_len", len(secret),
"secret_prefix", secretPrefix)
// Verify the code
valid := totp.Validate(code, secret)
slog.Debug("totp_verify_result",
"user_id", userID,
"valid", valid,
"secret_len", len(secret),
"secret_prefix", secretPrefix,
"server_time", time.Now().UTC().Format(time.RFC3339))
if !valid {
// Increment failed attempts
_, _ = s.cache.IncrementVerifyAttempts(ctx, userID)
return ErrTotpInvalidCode
}
// Clear attempt counter on success
_ = s.cache.ClearVerifyAttempts(ctx, userID)
return nil
}
// CreateLoginSession creates a temporary login session for 2FA
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
// Generate a random temp token
tempToken, err := generateRandomToken(32)
if err != nil {
return "", fmt.Errorf("generate temp token: %w", err)
}
session := &TotpLoginSession{
UserID: userID,
Email: email,
TokenExpiry: time.Now().Add(totpLoginTTL),
}
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
return "", fmt.Errorf("store login session: %w", err)
}
return tempToken, nil
}
// GetLoginSession retrieves a login session
func (s *TotpService) GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error) {
return s.cache.GetLoginSession(ctx, tempToken)
}
// DeleteLoginSession deletes a login session
func (s *TotpService) DeleteLoginSession(ctx context.Context, tempToken string) error {
return s.cache.DeleteLoginSession(ctx, tempToken)
}
// IsTotpEnabledForUser checks if TOTP is enabled for a specific user
func (s *TotpService) IsTotpEnabledForUser(ctx context.Context, userID int64) (bool, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return false, fmt.Errorf("get user: %w", err)
}
return user.TotpEnabled, nil
}
// MaskEmail masks an email address for display
func MaskEmail(email string) string {
if len(email) < 3 {
return "***"
}
atIdx := -1
for i, c := range email {
if c == '@' {
atIdx = i
break
}
}
if atIdx == -1 || atIdx < 1 {
return email[:1] + "***"
}
localPart := email[:atIdx]
domain := email[atIdx:]
if len(localPart) <= 2 {
return localPart[:1] + "***" + domain
}
return localPart[:1] + "***" + localPart[len(localPart)-1:] + domain
}
// generateRandomToken generates a random hex-encoded token
func generateRandomToken(byteLength int) (string, error) {
b := make([]byte, byteLength)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// VerificationMethod represents the method required for TOTP operations
type VerificationMethod struct {
Method string `json:"method"` // "email" or "password"
}
// GetVerificationMethod returns the verification method for TOTP operations
func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMethod {
if s.settingService.IsEmailVerifyEnabled(ctx) {
return &VerificationMethod{Method: "email"}
}
return &VerificationMethod{Method: "password"}
}
// SendVerifyCode sends an email verification code for TOTP operations
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
// Check if email verification is enabled
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
}
// Get user email
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
// Get site name for email
siteName := s.settingService.GetSiteName(ctx)
// Send verification code via queue
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName)
}
...@@ -14,6 +14,9 @@ type UsageLog struct { ...@@ -14,6 +14,9 @@ type UsageLog struct {
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
ReasoningEffort *string
GroupID *int64 GroupID *int64
SubscriptionID *int64 SubscriptionID *int64
......
...@@ -21,6 +21,11 @@ type User struct { ...@@ -21,6 +21,11 @@ type User struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
TotpEnabledAt *time.Time // TOTP 启用时间
APIKeys []APIKey APIKeys []APIKey
Subscriptions []UserSubscription Subscriptions []UserSubscription
} }
......
...@@ -38,6 +38,11 @@ type UserRepository interface { ...@@ -38,6 +38,11 @@ type UserRepository interface {
UpdateConcurrency(ctx context.Context, id int64, amount int) error UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error) ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// TOTP 相关方法
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
EnableTotp(ctx context.Context, userID int64) error
DisableTotp(ctx context.Context, userID int64) error
} }
// UpdateProfileRequest 更新用户资料请求 // UpdateProfileRequest 更新用户资料请求
......
...@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface { ...@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment