Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a161fcc8
Commit
a161fcc8
authored
Jan 26, 2026
by
cyhhao
Browse files
Merge branch 'main' of github.com:Wei-Shaw/sub2api
parents
65e69738
e32c5f53
Changes
119
Show whitespace changes
Inline
Side-by-side
backend/internal/service/email_service.go
View file @
a161fcc8
...
...
@@ -3,11 +3,14 @@ package service
import
(
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"math/big"
"net/smtp"
"net/url"
"strconv"
"time"
...
...
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
// Password reset errors
ErrInvalidResetToken
=
infraerrors
.
BadRequest
(
"INVALID_RESET_TOKEN"
,
"invalid or expired password reset token"
)
)
// EmailCache defines cache operations for email service
...
...
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
}
// VerificationCodeData represents verification code data
...
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt
time
.
Time
}
// PasswordResetTokenData represents password reset token data
type
PasswordResetTokenData
struct
{
Token
string
CreatedAt
time
.
Time
}
const
(
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
maxVerifyCodeAttempts
=
5
// Password reset token settings
passwordResetTokenTTL
=
30
*
time
.
Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown
=
30
*
time
.
Second
)
// SMTPConfig SMTP配置
...
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return
ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if
data
.
Code
!=
code
{
// 验证码不匹配
(constant-time comparison to prevent timing attacks)
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
...
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return
client
.
Quit
()
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func
(
s
*
EmailService
)
GeneratePasswordResetToken
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func
(
s
*
EmailService
)
SendPasswordResetEmail
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
var
token
string
var
needSaveToken
bool
// Check if token already exists
existing
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
==
nil
&&
existing
!=
nil
{
// Token exists, reuse it (allows resending email without generating new token)
token
=
existing
.
Token
needSaveToken
=
false
}
else
{
// Generate new token
token
,
err
=
s
.
GeneratePasswordResetToken
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
needSaveToken
=
true
}
// Save token to Redis (only if new token generated)
if
needSaveToken
{
data
:=
&
PasswordResetTokenData
{
Token
:
token
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetPasswordResetToken
(
ctx
,
email
,
data
,
passwordResetTokenTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save reset token: %w"
,
err
)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL
:=
fmt
.
Sprintf
(
"%s?email=%s&token=%s"
,
resetURL
,
url
.
QueryEscape
(
email
),
url
.
QueryEscape
(
token
))
// Build email content
subject
:=
fmt
.
Sprintf
(
"[%s] 密码重置请求"
,
siteName
)
body
:=
s
.
buildPasswordResetEmailBody
(
fullResetURL
,
siteName
)
// Send email
if
err
:=
s
.
SendEmail
(
ctx
,
email
,
subject
,
body
);
err
!=
nil
{
return
fmt
.
Errorf
(
"send email: %w"
,
err
)
}
return
nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] Password reset email skipped (cooldown): %s"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if
err
:=
s
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
return
err
}
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to set password reset cooldown for %s: %v"
,
email
,
err
)
}
return
nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func
(
s
*
EmailService
)
VerifyPasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
data
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Token
),
[]
byte
(
token
))
!=
1
{
return
ErrInvalidResetToken
}
return
nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func
(
s
*
EmailService
)
ConsumePasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
// Verify first
if
err
:=
s
.
VerifyPasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete password reset token after consumption: %v"
,
err
)
}
return
nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func
(
s
*
EmailService
)
buildPasswordResetEmailBody
(
resetURL
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`
,
siteName
,
resetURL
,
resetURL
)
}
backend/internal/service/gateway_service.go
View file @
a161fcc8
...
...
@@ -342,6 +342,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
accountID
,
stickySessionTTL
)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func
(
s
*
GatewayService
)
GetCachedSessionAccountID
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
int64
,
error
)
{
if
sessionHash
==
""
||
s
.
cache
==
nil
{
return
0
,
nil
}
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
!=
nil
{
return
0
,
err
}
return
accountID
,
nil
}
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
if
parsed
==
nil
{
return
""
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
a161fcc8
...
...
@@ -1972,6 +1972,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
var
last
map
[
string
]
any
var
lastWithParts
map
[
string
]
any
var
collectedTextParts
[]
string
// Collect all text parts for aggregation
usage
:=
&
ClaudeUsage
{}
for
{
...
...
@@ -1983,7 +1984,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
switch
payload
{
case
""
,
"[DONE]"
:
if
payload
==
"[DONE]"
{
return
pickGeminiCollectResult
(
last
,
lastWithParts
),
usage
,
nil
return
mergeCollectedTextParts
(
pickGeminiCollectResult
(
last
,
lastWithParts
),
collectedTextParts
),
usage
,
nil
}
default
:
var
parsed
map
[
string
]
any
...
...
@@ -2002,6 +2003,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
lastWithParts
=
parsed
// Collect text from each part for aggregation
for
_
,
part
:=
range
parts
{
if
text
,
ok
:=
part
[
"text"
]
.
(
string
);
ok
&&
text
!=
""
{
collectedTextParts
=
append
(
collectedTextParts
,
text
)
}
}
}
}
}
...
...
@@ -2016,7 +2023,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
}
return
pickGeminiCollectResult
(
last
,
lastWithParts
),
usage
,
nil
return
mergeCollectedTextParts
(
pickGeminiCollectResult
(
last
,
lastWithParts
),
collectedTextParts
),
usage
,
nil
}
func
pickGeminiCollectResult
(
last
map
[
string
]
any
,
lastWithParts
map
[
string
]
any
)
map
[
string
]
any
{
...
...
@@ -2029,6 +2036,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
return
map
[
string
]
any
{}
}
// mergeCollectedTextParts merges all collected text chunks into the final response.
// This fixes the issue where non-streaming responses only returned the last chunk
// instead of the complete aggregated text.
func
mergeCollectedTextParts
(
response
map
[
string
]
any
,
textParts
[]
string
)
map
[
string
]
any
{
if
len
(
textParts
)
==
0
{
return
response
}
// Join all text parts
mergedText
:=
strings
.
Join
(
textParts
,
""
)
// Deep copy response
result
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
response
{
result
[
k
]
=
v
}
// Get or create candidates
candidates
,
ok
:=
result
[
"candidates"
]
.
([]
any
)
if
!
ok
||
len
(
candidates
)
==
0
{
candidates
=
[]
any
{
map
[
string
]
any
{}}
}
// Get first candidate
candidate
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
candidate
=
make
(
map
[
string
]
any
)
candidates
[
0
]
=
candidate
}
// Get or create content
content
,
ok
:=
candidate
[
"content"
]
.
(
map
[
string
]
any
)
if
!
ok
{
content
=
map
[
string
]
any
{
"role"
:
"model"
}
candidate
[
"content"
]
=
content
}
// Get existing parts
existingParts
,
ok
:=
content
[
"parts"
]
.
([]
any
)
if
!
ok
{
existingParts
=
[]
any
{}
}
// Find and update first text part, or create new one
newParts
:=
make
([]
any
,
0
,
len
(
existingParts
)
+
1
)
textUpdated
:=
false
for
_
,
p
:=
range
existingParts
{
pm
,
ok
:=
p
.
(
map
[
string
]
any
)
if
!
ok
{
newParts
=
append
(
newParts
,
p
)
continue
}
if
_
,
hasText
:=
pm
[
"text"
];
hasText
&&
!
textUpdated
{
// Replace with merged text
newPart
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
pm
{
newPart
[
k
]
=
v
}
newPart
[
"text"
]
=
mergedText
newParts
=
append
(
newParts
,
newPart
)
textUpdated
=
true
}
else
{
newParts
=
append
(
newParts
,
pm
)
}
}
if
!
textUpdated
{
newParts
=
append
([]
any
{
map
[
string
]
any
{
"text"
:
mergedText
}},
newParts
...
)
}
content
[
"parts"
]
=
newParts
result
[
"candidates"
]
=
candidates
return
result
}
type
geminiNativeStreamResult
struct
{
usage
*
ClaudeUsage
firstTokenMs
*
int
...
...
backend/internal/service/gemini_native_signature_cleaner.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
// 解析 JSON
var
data
any
if
err
:=
json
.
Unmarshal
(
body
,
&
data
);
err
!=
nil
{
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
return
body
}
// 递归清理 thoughtSignature
cleaned
:=
cleanThoughtSignaturesRecursive
(
data
)
// 重新序列化
result
,
err
:=
json
.
Marshal
(
cleaned
)
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
return
body
}
return
result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func
cleanThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
// 创建新的 map,移除 thoughtSignature
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
// 跳过 thoughtSignature 字段
if
key
==
"thoughtSignature"
{
continue
}
// 递归处理嵌套结构
result
[
key
]
=
cleanThoughtSignaturesRecursive
(
value
)
}
return
result
case
[]
any
:
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
result
[
i
]
=
cleanThoughtSignaturesRecursive
(
item
)
}
return
result
default
:
// 基本类型(string, number, bool, null)直接返回
return
v
}
}
backend/internal/service/gemini_token_provider.go
View file @
a161fcc8
...
...
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
...
...
@@ -131,8 +132,18 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
// 3) Populate cache with TTL
.
// 3) Populate cache with TTL
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"gemini_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
...
...
@@ -147,6 +158,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
}
return
accessToken
,
nil
}
...
...
backend/internal/service/oauth_service.go
View file @
a161fcc8
...
...
@@ -122,6 +122,7 @@ type TokenInfo struct {
Scope
string
`json:"scope,omitempty"`
OrgUUID
string
`json:"org_uuid,omitempty"`
AccountUUID
string
`json:"account_uuid,omitempty"`
EmailAddress
string
`json:"email_address,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
...
...
@@ -252,10 +253,16 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
tokenInfo
.
OrgUUID
=
tokenResp
.
Organization
.
UUID
log
.
Printf
(
"[OAuth] Got org_uuid: %s"
,
tokenInfo
.
OrgUUID
)
}
if
tokenResp
.
Account
!=
nil
&&
tokenResp
.
Account
.
UUID
!=
""
{
if
tokenResp
.
Account
!=
nil
{
if
tokenResp
.
Account
.
UUID
!=
""
{
tokenInfo
.
AccountUUID
=
tokenResp
.
Account
.
UUID
log
.
Printf
(
"[OAuth] Got account_uuid: %s"
,
tokenInfo
.
AccountUUID
)
}
if
tokenResp
.
Account
.
EmailAddress
!=
""
{
tokenInfo
.
EmailAddress
=
tokenResp
.
Account
.
EmailAddress
log
.
Printf
(
"[OAuth] Got email_address: %s"
,
tokenInfo
.
EmailAddress
)
}
}
return
tokenInfo
,
nil
}
...
...
backend/internal/service/openai_gateway_service.go
View file @
a161fcc8
...
...
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
UpdatedAt
string
`json:"updated_at,omitempty"`
}
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
type
NormalizedCodexLimits
struct
{
Used5hPercent
*
float64
Reset5hSeconds
*
int
Window5hMinutes
*
int
Used7dPercent
*
float64
Reset7dSeconds
*
int
Window7dMinutes
*
int
}
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
// Returns nil if snapshot is nil or has no useful data.
func
(
s
*
OpenAICodexUsageSnapshot
)
Normalize
()
*
NormalizedCodexLimits
{
if
s
==
nil
{
return
nil
}
result
:=
&
NormalizedCodexLimits
{}
primaryMins
:=
0
secondaryMins
:=
0
hasPrimaryWindow
:=
false
hasSecondaryWindow
:=
false
if
s
.
PrimaryWindowMinutes
!=
nil
{
primaryMins
=
*
s
.
PrimaryWindowMinutes
hasPrimaryWindow
=
true
}
if
s
.
SecondaryWindowMinutes
!=
nil
{
secondaryMins
=
*
s
.
SecondaryWindowMinutes
hasSecondaryWindow
=
true
}
// Determine mapping based on window_minutes
use5hFromPrimary
:=
false
use7dFromPrimary
:=
false
if
hasPrimaryWindow
&&
hasSecondaryWindow
{
// Both known: smaller window is 5h, larger is 7d
if
primaryMins
<
secondaryMins
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasPrimaryWindow
{
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
if
primaryMins
<=
360
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasSecondaryWindow
{
// Only secondary known: classify by threshold
if
secondaryMins
<=
360
{
// 5h from secondary, so primary (if any data) is 7d
use7dFromPrimary
=
true
}
else
{
// 7d from secondary, so primary (if any data) is 5h
use5hFromPrimary
=
true
}
}
else
{
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
use7dFromPrimary
=
true
}
// Assign values
if
use5hFromPrimary
{
result
.
Used5hPercent
=
s
.
PrimaryUsedPercent
result
.
Reset5hSeconds
=
s
.
PrimaryResetAfterSeconds
result
.
Window5hMinutes
=
s
.
PrimaryWindowMinutes
result
.
Used7dPercent
=
s
.
SecondaryUsedPercent
result
.
Reset7dSeconds
=
s
.
SecondaryResetAfterSeconds
result
.
Window7dMinutes
=
s
.
SecondaryWindowMinutes
}
else
if
use7dFromPrimary
{
result
.
Used7dPercent
=
s
.
PrimaryUsedPercent
result
.
Reset7dSeconds
=
s
.
PrimaryResetAfterSeconds
result
.
Window7dMinutes
=
s
.
PrimaryWindowMinutes
result
.
Used5hPercent
=
s
.
SecondaryUsedPercent
result
.
Reset5hSeconds
=
s
.
SecondaryResetAfterSeconds
result
.
Window5hMinutes
=
s
.
SecondaryWindowMinutes
}
return
result
}
// OpenAIUsage represents OpenAI API response usage
type
OpenAIUsage
struct
{
InputTokens
int
`json:"input_tokens"`
...
...
@@ -867,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if
account
.
Type
==
AccountTypeOAuth
{
if
snapshot
:=
extractCodexUsage
Headers
(
resp
.
Header
);
snapshot
!=
nil
{
if
snapshot
:=
ParseCodexRateLimit
Headers
(
resp
.
Header
);
snapshot
!=
nil
{
s
.
updateCodexUsageSnapshot
(
ctx
,
account
.
ID
,
snapshot
)
}
}
...
...
@@ -1706,8 +1792,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return
nil
}
// extractCodexUsageHeaders extracts Codex usage limits from response headers
func
extractCodexUsageHeaders
(
headers
http
.
Header
)
*
OpenAICodexUsageSnapshot
{
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func
ParseCodexRateLimitHeaders
(
headers
http
.
Header
)
*
OpenAICodexUsageSnapshot
{
snapshot
:=
&
OpenAICodexUsageSnapshot
{}
hasData
:=
false
...
...
@@ -1781,6 +1868,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
// Convert snapshot to map for merging into Extra
updates
:=
make
(
map
[
string
]
any
)
// Save raw primary/secondary fields for debugging/tracing
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_primary_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
...
...
@@ -1804,109 +1893,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
}
updates
[
"codex_usage_updated_at"
]
=
snapshot
.
UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes
// This fixes the issue where OpenAI's primary/secondary naming is reversed
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
// IMPORTANT: We can only reliably determine window type from window_minutes field
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
var
primaryWindowMins
,
secondaryWindowMins
int
var
hasPrimaryWindow
,
hasSecondaryWindow
bool
// Only use window_minutes for reliable window size comparison
if
snapshot
.
PrimaryWindowMinutes
!=
nil
{
primaryWindowMins
=
*
snapshot
.
PrimaryWindowMinutes
hasPrimaryWindow
=
true
}
if
snapshot
.
SecondaryWindowMinutes
!=
nil
{
secondaryWindowMins
=
*
snapshot
.
SecondaryWindowMinutes
hasSecondaryWindow
=
true
}
// Determine which is 5h and which is 7d
var
use5hFromPrimary
,
use7dFromPrimary
bool
var
use5hFromSecondary
,
use7dFromSecondary
bool
if
hasPrimaryWindow
&&
hasSecondaryWindow
{
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
if
primaryWindowMins
<
secondaryWindowMins
{
use5hFromPrimary
=
true
use7dFromSecondary
=
true
}
else
{
use5hFromSecondary
=
true
use7dFromPrimary
=
true
}
}
else
if
hasPrimaryWindow
{
// Only primary window size known: classify by absolute threshold
if
primaryWindowMins
<=
360
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasSecondaryWindow
{
// Only secondary window size known: classify by absolute threshold
if
secondaryWindowMins
<=
360
{
use5hFromSecondary
=
true
}
else
{
use7dFromSecondary
=
true
}
}
else
{
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if
snapshot
.
SecondaryUsedPercent
!=
nil
||
snapshot
.
SecondaryResetAfterSeconds
!=
nil
||
snapshot
.
SecondaryWindowMinutes
!=
nil
{
use5hFromSecondary
=
true
}
if
snapshot
.
PrimaryUsedPercent
!=
nil
||
snapshot
.
PrimaryResetAfterSeconds
!=
nil
||
snapshot
.
PrimaryWindowMinutes
!=
nil
{
use7dFromPrimary
=
true
}
}
// Write canonical 5h fields
if
use5hFromPrimary
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
if
snapshot
.
PrimaryResetAfterSeconds
!=
nil
{
updates
[
"codex_5h_reset_after_seconds"
]
=
*
snapshot
.
PrimaryResetAfterSeconds
}
if
snapshot
.
PrimaryWindowMinutes
!=
nil
{
updates
[
"codex_5h_window_minutes"
]
=
*
snapshot
.
PrimaryWindowMinutes
}
}
else
if
use5hFromSecondary
{
if
snapshot
.
SecondaryUsedPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
snapshot
.
SecondaryUsedPercent
}
if
snapshot
.
SecondaryResetAfterSeconds
!=
nil
{
updates
[
"codex_5h_reset_after_seconds"
]
=
*
snapshot
.
SecondaryResetAfterSeconds
}
if
snapshot
.
SecondaryWindowMinutes
!=
nil
{
updates
[
"codex_5h_window_minutes"
]
=
*
snapshot
.
SecondaryWindowMinutes
}
}
// Write canonical 7d fields
if
use7dFromPrimary
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_7d_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
// Normalize to canonical 5h/7d fields
if
normalized
:=
snapshot
.
Normalize
();
normalized
!=
nil
{
if
normalized
.
Used5hPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
normalized
.
Used5hPercent
}
if
snapshot
.
PrimaryResetAfter
Seconds
!=
nil
{
updates
[
"codex_
7d
_reset_after_seconds"
]
=
*
snapshot
.
PrimaryResetAfter
Seconds
if
normalized
.
Reset5h
Seconds
!=
nil
{
updates
[
"codex_
5h
_reset_after_seconds"
]
=
*
normalized
.
Reset5h
Seconds
}
if
snapshot
.
Primary
WindowMinutes
!=
nil
{
updates
[
"codex_
7d
_window_minutes"
]
=
*
snapshot
.
Primary
WindowMinutes
if
normalized
.
Window
5h
Minutes
!=
nil
{
updates
[
"codex_
5h
_window_minutes"
]
=
*
normalized
.
Window
5h
Minutes
}
}
else
if
use7dFromSecondary
{
if
snapshot
.
SecondaryUsedPercent
!=
nil
{
updates
[
"codex_7d_used_percent"
]
=
*
snapshot
.
SecondaryUsedPercent
if
normalized
.
Used7dPercent
!=
nil
{
updates
[
"codex_7d_used_percent"
]
=
*
normalized
.
Used7dPercent
}
if
snapshot
.
SecondaryResetAfter
Seconds
!=
nil
{
updates
[
"codex_7d_reset_after_seconds"
]
=
*
snapshot
.
SecondaryResetAfter
Seconds
if
normalized
.
Reset7d
Seconds
!=
nil
{
updates
[
"codex_7d_reset_after_seconds"
]
=
*
normalized
.
Reset7d
Seconds
}
if
snapshot
.
Secondary
WindowMinutes
!=
nil
{
updates
[
"codex_7d_window_minutes"
]
=
*
snapshot
.
Secondary
WindowMinutes
if
normalized
.
Window
7d
Minutes
!=
nil
{
updates
[
"codex_7d_window_minutes"
]
=
*
normalized
.
Window
7d
Minutes
}
}
...
...
backend/internal/service/openai_token_provider.go
View file @
a161fcc8
...
...
@@ -162,8 +162,18 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
// 3. 存入缓存
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"openai_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetOpenAIAccessToken
()
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
...
...
@@ -184,6 +194,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
return
accessToken
,
nil
}
backend/internal/service/ratelimit_service.go
View file @
a161fcc8
...
...
@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func
(
s
*
RateLimitService
)
handle429
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
,
responseBody
[]
byte
)
{
// 解析重置时间戳
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
if
account
.
Platform
==
PlatformOpenAI
{
if
resetAt
:=
s
.
calculateOpenAI429ResetTime
(
headers
);
resetAt
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
*
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"openai_account_rate_limited"
,
"account_id"
,
account
.
ID
,
"reset_at"
,
*
resetAt
)
return
}
}
// 2. 尝试从响应头解析重置时间(Anthropic)
resetTimestamp
:=
headers
.
Get
(
"anthropic-ratelimit-unified-reset"
)
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
if
resetTimestamp
==
""
{
switch
account
.
Platform
{
case
PlatformOpenAI
:
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if
resetAt
:=
parseOpenAIRateLimitResetTime
(
responseBody
);
resetAt
!=
nil
{
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"account_rate_limited"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"reset_at"
,
resetTime
,
"reset_in"
,
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
return
}
case
PlatformGemini
,
PlatformAntigravity
:
// 尝试解析 Gemini 格式(用于其他平台)
if
resetAt
:=
ParseGeminiRateLimitResetTime
(
responseBody
);
resetAt
!=
nil
{
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"account_rate_limited"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"reset_at"
,
resetTime
,
"reset_in"
,
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
return
}
}
// 没有重置时间,使用默认5分钟
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
...
...
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
return
}
slog
.
Warn
(
"rate_limit_no_reset_time"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"using_default"
,
"5m"
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
...
...
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
return
strings
.
Contains
(
msg
,
"sonnet"
)
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func
(
s
*
RateLimitService
)
calculateOpenAI429ResetTime
(
headers
http
.
Header
)
*
time
.
Time
{
snapshot
:=
ParseCodexRateLimitHeaders
(
headers
)
if
snapshot
==
nil
{
return
nil
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
return
nil
}
now
:=
time
.
Now
()
// 判断哪个限制被触发(used_percent >= 100)
is7dExhausted
:=
normalized
.
Used7dPercent
!=
nil
&&
*
normalized
.
Used7dPercent
>=
100
is5hExhausted
:=
normalized
.
Used5hPercent
!=
nil
&&
*
normalized
.
Used5hPercent
>=
100
// 优先使用被触发限制的重置时间
if
is7dExhausted
&&
normalized
.
Reset7dSeconds
!=
nil
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
*
normalized
.
Reset7dSeconds
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_7d_limit_exhausted"
,
"reset_after_seconds"
,
*
normalized
.
Reset7dSeconds
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
if
is5hExhausted
&&
normalized
.
Reset5hSeconds
!=
nil
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
*
normalized
.
Reset5hSeconds
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_5h_limit_exhausted"
,
"reset_after_seconds"
,
*
normalized
.
Reset5hSeconds
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
// 都未达到100%但收到429,使用较长的重置时间
var
maxResetSecs
int
if
normalized
.
Reset7dSeconds
!=
nil
&&
*
normalized
.
Reset7dSeconds
>
maxResetSecs
{
maxResetSecs
=
*
normalized
.
Reset7dSeconds
}
if
normalized
.
Reset5hSeconds
!=
nil
&&
*
normalized
.
Reset5hSeconds
>
maxResetSecs
{
maxResetSecs
=
*
normalized
.
Reset5hSeconds
}
if
maxResetSecs
>
0
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
maxResetSecs
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_using_max_reset"
,
"max_reset_seconds"
,
maxResetSecs
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
return
nil
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
// {
// "error": {
// "message": "The usage limit has been reached",
// "type": "usage_limit_reached",
// "resets_at": 1769404154,
// "resets_in_seconds": 133107
// }
// }
func
parseOpenAIRateLimitResetTime
(
body
[]
byte
)
*
int64
{
var
parsed
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
parsed
);
err
!=
nil
{
return
nil
}
errObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
errType
,
_
:=
errObj
[
"type"
]
.
(
string
)
if
errType
!=
"usage_limit_reached"
&&
errType
!=
"rate_limit_exceeded"
{
return
nil
}
// 优先使用 resets_at(Unix 时间戳)
if
resetsAt
,
ok
:=
errObj
[
"resets_at"
]
.
(
float64
);
ok
{
ts
:=
int64
(
resetsAt
)
return
&
ts
}
if
resetsAt
,
ok
:=
errObj
[
"resets_at"
]
.
(
string
);
ok
{
if
ts
,
err
:=
strconv
.
ParseInt
(
resetsAt
,
10
,
64
);
err
==
nil
{
return
&
ts
}
}
// 如果没有 resets_at,尝试使用 resets_in_seconds
if
resetsInSeconds
,
ok
:=
errObj
[
"resets_in_seconds"
]
.
(
float64
);
ok
{
ts
:=
time
.
Now
()
.
Unix
()
+
int64
(
resetsInSeconds
)
return
&
ts
}
if
resetsInSeconds
,
ok
:=
errObj
[
"resets_in_seconds"
]
.
(
string
);
ok
{
if
sec
,
err
:=
strconv
.
ParseInt
(
resetsInSeconds
,
10
,
64
);
err
==
nil
{
ts
:=
time
.
Now
()
.
Unix
()
+
sec
return
&
ts
}
}
return
nil
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func
(
s
*
RateLimitService
)
handle529
(
ctx
context
.
Context
,
account
*
Account
)
{
...
...
backend/internal/service/ratelimit_service_openai_test.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"net/http"
"testing"
"time"
)
func
TestCalculateOpenAI429ResetTime_7dExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
// ~4.5 days
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
// ~4.8 hours
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 384607 seconds from now
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_5hExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 5h limit is exhausted (100% used)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 3600 seconds from now
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Neither limit at 100%, should use the longer reset time
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"80"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"100000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"90"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"5000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should use the max (100000 seconds from 7d window)
expectedDuration
:=
100000
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NoCodexHeaders
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// No codex headers at all
headers
:=
http
.
Header
{}
headers
.
Set
(
"content-type"
,
"application/json"
)
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil resetAt when no codex headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_ReversedWindowOrder
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// This is 5h
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"300"
)
// 5 hours - smaller!
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"10080"
)
// 7 days - larger!
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestNormalizedCodexLimits
(
t
*
testing
.
T
)
{
// Test the Normalize() method directly
pUsed
:=
100.0
pReset
:=
384607
pWindow
:=
10080
sUsed
:=
3.0
sReset
:=
17369
sWindow
:=
300
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
PrimaryWindowMinutes
:
&
pWindow
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
SecondaryWindowMinutes
:
&
sWindow
,
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Primary has larger window (10080 > 300), so primary should be 7d
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
384607
{
t
.
Errorf
(
"expected Reset7dSeconds=384607, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
3.0
{
t
.
Errorf
(
"expected Used5hPercent=3, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
17369
{
t
.
Errorf
(
"expected Reset5hSeconds=17369, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlyPrimaryData
(
t
*
testing
.
T
)
{
// Test when only primary has data, no window_minutes
pUsed
:=
80.0
pReset
:=
50000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
// No window_minutes, no secondary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
80.0
{
t
.
Errorf
(
"expected Used7dPercent=80, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
50000
{
t
.
Errorf
(
"expected Reset7dSeconds=50000, got %v"
,
normalized
.
Reset7dSeconds
)
}
// Secondary (5h) should be nil
if
normalized
.
Used5hPercent
!=
nil
{
t
.
Errorf
(
"expected Used5hPercent=nil, got %v"
,
*
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
!=
nil
{
t
.
Errorf
(
"expected Reset5hSeconds=nil, got %v"
,
*
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlySecondaryData
(
t
*
testing
.
T
)
{
// Test when only secondary has data, no window_minutes
sUsed
:=
60.0
sReset
:=
3000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes, no primary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
60.0
{
t
.
Errorf
(
"expected Used5hPercent=60, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
3000
{
t
.
Errorf
(
"expected Reset5hSeconds=3000, got %v"
,
normalized
.
Reset5hSeconds
)
}
// Primary (7d) should be nil
if
normalized
.
Used7dPercent
!=
nil
{
t
.
Errorf
(
"expected Used7dPercent=nil, got %v"
,
*
normalized
.
Used7dPercent
)
}
}
func
TestNormalizedCodexLimits_BothDataNoWindowMinutes
(
t
*
testing
.
T
)
{
// Test when both have data but no window_minutes
pUsed
:=
100.0
pReset
:=
400000
sUsed
:=
50.0
sReset
:=
10000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
400000
{
t
.
Errorf
(
"expected Reset7dSeconds=400000, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
50.0
{
t
.
Errorf
(
"expected Used5hPercent=50, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
10000
{
t
.
Errorf
(
"expected Reset5hSeconds=10000, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestHandle429_AnthropicPlatformUnaffected
(
t
*
testing
.
T
)
{
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc
:=
&
RateLimitService
{}
// Simulate Anthropic 429 headers
headers
:=
http
.
Header
{}
headers
.
Set
(
"anthropic-ratelimit-unified-reset"
,
"1737820800"
)
// A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there are no x-codex-* headers
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil for Anthropic headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_UserProvidedScenario
(
t
*
testing
.
T
)
{
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc
:=
&
RateLimitService
{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days = 10080 minutes
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours = 300 minutes
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt for user scenario"
)
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration
:=
resetAt
.
Sub
(
before
)
actualDays
:=
duration
.
Hours
()
/
24.0
// 384607 / 86400 = ~4.45 days
if
actualDays
<
4.4
||
actualDays
>
4.5
{
t
.
Errorf
(
"expected ~4.45 days, got %.2f days"
,
actualDays
)
}
t
.
Logf
(
"User scenario: reset_at=%v, duration=%.2f days"
,
resetAt
,
actualDays
)
}
func
TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset
(
t
*
testing
.
T
)
{
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc
:=
&
RateLimitService
{}
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// No reset_after_seconds!
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there's no reset time available
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil when no reset_after_seconds, got %v"
,
resetAt
)
}
}
backend/internal/service/setting_service.go
View file @
a161fcc8
...
...
@@ -61,6 +61,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyRegistrationEnabled
,
SettingKeyEmailVerifyEnabled
,
SettingKeyPromoCodeEnabled
,
SettingKeyPasswordResetEnabled
,
SettingKeyTotpEnabled
,
SettingKeyTurnstileEnabled
,
SettingKeyTurnstileSiteKey
,
SettingKeySiteName
,
...
...
@@ -86,10 +88,16 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
linuxDoEnabled
=
s
.
cfg
!=
nil
&&
s
.
cfg
.
LinuxDo
.
Enabled
}
// Password reset requires email verification to be enabled
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
passwordResetEnabled
:=
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
return
&
PublicSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyE
mailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
e
mailVerifyEnabled
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PasswordResetEnabled
:
passwordResetEnabled
,
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
...
...
@@ -128,6 +136,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
TotpEnabled
bool
`json:"totp_enabled"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key,omitempty"`
SiteName
string
`json:"site_name"`
...
...
@@ -144,6 +154,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
PromoCodeEnabled
:
settings
.
PromoCodeEnabled
,
PasswordResetEnabled
:
settings
.
PasswordResetEnabled
,
TotpEnabled
:
settings
.
TotpEnabled
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
SiteName
:
settings
.
SiteName
,
...
...
@@ -167,6 +179,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyPromoCodeEnabled
]
=
strconv
.
FormatBool
(
settings
.
PromoCodeEnabled
)
updates
[
SettingKeyPasswordResetEnabled
]
=
strconv
.
FormatBool
(
settings
.
PasswordResetEnabled
)
updates
[
SettingKeyTotpEnabled
]
=
strconv
.
FormatBool
(
settings
.
TotpEnabled
)
// 邮件服务设置(只有非空才更新密码)
updates
[
SettingKeySMTPHost
]
=
settings
.
SMTPHost
...
...
@@ -262,6 +276,35 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
return
value
!=
"false"
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func
(
s
*
SettingService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
// Password reset requires email verification to be enabled
if
!
s
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyPasswordResetEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
func
(
s
*
SettingService
)
IsTotpEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyTotpEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
func
(
s
*
SettingService
)
IsTotpEncryptionKeyConfigured
()
bool
{
return
s
.
cfg
.
Totp
.
EncryptionKeyConfigured
}
// GetSiteName 获取网站名称
func
(
s
*
SettingService
)
GetSiteName
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
...
...
@@ -340,10 +383,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyE
mailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
e
mailVerifyEnabled
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PasswordResetEnabled
:
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
,
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
...
...
backend/internal/service/settings_view.go
View file @
a161fcc8
...
...
@@ -4,6 +4,8 @@ type SystemSettings struct {
RegistrationEnabled
bool
EmailVerifyEnabled
bool
PromoCodeEnabled
bool
PasswordResetEnabled
bool
TotpEnabled
bool
// TOTP 双因素认证
SMTPHost
string
SMTPPort
int
...
...
@@ -60,6 +62,8 @@ type PublicSettings struct {
RegistrationEnabled
bool
EmailVerifyEnabled
bool
PromoCodeEnabled
bool
PasswordResetEnabled
bool
TotpEnabled
bool
// TOTP 双因素认证
TurnstileEnabled
bool
TurnstileSiteKey
string
SiteName
string
...
...
backend/internal/service/subscription_expiry_service.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"context"
"log"
"sync"
"time"
)
// SubscriptionExpiryService periodically updates expired subscription status.
type
SubscriptionExpiryService
struct
{
userSubRepo
UserSubscriptionRepository
interval
time
.
Duration
stopCh
chan
struct
{}
stopOnce
sync
.
Once
wg
sync
.
WaitGroup
}
func
NewSubscriptionExpiryService
(
userSubRepo
UserSubscriptionRepository
,
interval
time
.
Duration
)
*
SubscriptionExpiryService
{
return
&
SubscriptionExpiryService
{
userSubRepo
:
userSubRepo
,
interval
:
interval
,
stopCh
:
make
(
chan
struct
{}),
}
}
func
(
s
*
SubscriptionExpiryService
)
Start
()
{
if
s
==
nil
||
s
.
userSubRepo
==
nil
||
s
.
interval
<=
0
{
return
}
s
.
wg
.
Add
(
1
)
go
func
()
{
defer
s
.
wg
.
Done
()
ticker
:=
time
.
NewTicker
(
s
.
interval
)
defer
ticker
.
Stop
()
s
.
runOnce
()
for
{
select
{
case
<-
ticker
.
C
:
s
.
runOnce
()
case
<-
s
.
stopCh
:
return
}
}
}()
}
func
(
s
*
SubscriptionExpiryService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
s
.
wg
.
Wait
()
}
func
(
s
*
SubscriptionExpiryService
)
runOnce
()
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
updated
,
err
:=
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
if
err
!=
nil
{
log
.
Printf
(
"[SubscriptionExpiry] Update expired subscriptions failed: %v"
,
err
)
return
}
if
updated
>
0
{
log
.
Printf
(
"[SubscriptionExpiry] Updated %d expired subscriptions"
,
updated
)
}
}
backend/internal/service/subscription_service.go
View file @
a161fcc8
...
...
@@ -324,19 +324,32 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
days
=
-
MaxValidityDays
}
now
:=
time
.
Now
()
isExpired
:=
!
sub
.
ExpiresAt
.
After
(
now
)
// 如果订阅已过期,不允许负向调整
if
isExpired
&&
days
<
0
{
return
nil
,
infraerrors
.
BadRequest
(
"CANNOT_SHORTEN_EXPIRED"
,
"cannot shorten an expired subscription"
)
}
// 计算新的过期时间
newExpiresAt
:=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
var
newExpiresAt
time
.
Time
if
isExpired
{
// 已过期:从当前时间开始增加天数
newExpiresAt
=
now
.
AddDate
(
0
,
0
,
days
)
}
else
{
// 未过期:从原过期时间增加/减少天数
newExpiresAt
=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
}
if
newExpiresAt
.
After
(
MaxExpiresAt
)
{
newExpiresAt
=
MaxExpiresAt
}
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
if
days
<
0
{
now
:=
time
.
Now
()
// 检查新的过期时间必须大于当前时间
if
!
newExpiresAt
.
After
(
now
)
{
return
nil
,
ErrAdjustWouldExpire
}
}
if
err
:=
s
.
userSubRepo
.
ExtendExpiry
(
ctx
,
subscriptionID
,
newExpiresAt
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
return
nil
,
err
}
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
nil
}
...
...
@@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
return
nil
,
nil
,
err
}
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
}
// List 获取所有订阅(分页,支持筛选)
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
// List 获取所有订阅(分页,支持筛选
和排序
)
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
)
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
,
sortBy
,
sortOrder
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
}
...
...
@@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
}
}
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
func
normalizeSubscriptionStatus
(
subs
[]
UserSubscription
)
{
now
:=
time
.
Now
()
for
i
:=
range
subs
{
sub
:=
&
subs
[
i
]
if
sub
.
Status
==
SubscriptionStatusActive
&&
!
sub
.
ExpiresAt
.
After
(
now
)
{
sub
.
Status
=
SubscriptionStatusExpired
}
}
}
// startOfDay 返回给定时间所在日期的零点(保持原时区)
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
...
...
@@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
return
progresses
,
nil
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func
(
s
*
SubscriptionService
)
UpdateExpiredSubscriptions
(
ctx
context
.
Context
)
(
int64
,
error
)
{
return
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
}
// ValidateSubscription 验证订阅是否有效
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
...
...
backend/internal/service/token_cache_invalidator.go
View file @
a161fcc8
package
service
import
"context"
import
(
"context"
"log/slog"
"strconv"
)
type
TokenCacheInvalidator
interface
{
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
...
...
@@ -24,18 +28,87 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
return
nil
}
var
cacheKey
string
var
keysToDelete
[]
string
accountIDKey
:=
"account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
switch
account
.
Platform
{
case
PlatformGemini
:
cacheKey
=
GeminiTokenCacheKey
(
account
)
// Gemini 可能有两种缓存键:project_id 或 account_id
// 首次获取 token 时可能没有 project_id,之后自动检测到 project_id 后会使用新 key
// 刷新时需要同时删除两种可能的 key,确保不会遗留旧缓存
keysToDelete
=
append
(
keysToDelete
,
GeminiTokenCacheKey
(
account
))
keysToDelete
=
append
(
keysToDelete
,
"gemini:"
+
accountIDKey
)
case
PlatformAntigravity
:
cacheKey
=
AntigravityTokenCacheKey
(
account
)
// Antigravity 同样可能有两种缓存键
keysToDelete
=
append
(
keysToDelete
,
AntigravityTokenCacheKey
(
account
))
keysToDelete
=
append
(
keysToDelete
,
"ag:"
+
accountIDKey
)
case
PlatformOpenAI
:
cacheKey
=
OpenAITokenCacheKey
(
account
)
keysToDelete
=
append
(
keysToDelete
,
OpenAITokenCacheKey
(
account
)
)
case
PlatformAnthropic
:
cacheKey
=
ClaudeTokenCacheKey
(
account
)
keysToDelete
=
append
(
keysToDelete
,
ClaudeTokenCacheKey
(
account
)
)
default
:
return
nil
}
return
c
.
cache
.
DeleteAccessToken
(
ctx
,
cacheKey
)
// 删除所有可能的缓存键(去重后)
seen
:=
make
(
map
[
string
]
bool
)
for
_
,
key
:=
range
keysToDelete
{
if
seen
[
key
]
{
continue
}
seen
[
key
]
=
true
if
err
:=
c
.
cache
.
DeleteAccessToken
(
ctx
,
key
);
err
!=
nil
{
slog
.
Warn
(
"token_cache_delete_failed"
,
"key"
,
key
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
nil
}
// CheckTokenVersion 检查 account 的 token 版本是否已过时,并返回最新的 account
// 用于解决异步刷新任务与请求线程的竞态条件:
// 如果刷新任务已更新 token 并删除缓存,此时请求线程的旧 account 对象不应写入缓存
//
// 返回值:
// - latestAccount: 从 DB 获取的最新 account(如果查询失败则返回 nil)
// - isStale: true 表示 token 已过时(应使用 latestAccount),false 表示可以使用当前 account
func
CheckTokenVersion
(
ctx
context
.
Context
,
account
*
Account
,
repo
AccountRepository
)
(
latestAccount
*
Account
,
isStale
bool
)
{
if
account
==
nil
||
repo
==
nil
{
return
nil
,
false
}
currentVersion
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
latestAccount
,
err
:=
repo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
!=
nil
||
latestAccount
==
nil
{
// 查询失败,默认允许缓存,不返回 latestAccount
return
nil
,
false
}
latestVersion
:=
latestAccount
.
GetCredentialAsInt64
(
"_token_version"
)
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
// 说明异步刷新任务已更新 token,当前 account 已过时
if
currentVersion
==
0
&&
latestVersion
>
0
{
slog
.
Debug
(
"token_version_stale_no_current_version"
,
"account_id"
,
account
.
ID
,
"latest_version"
,
latestVersion
)
return
latestAccount
,
true
}
// 情况2: 两边都没有版本号,说明从未被异步刷新过,允许缓存
if
currentVersion
==
0
&&
latestVersion
==
0
{
return
latestAccount
,
false
}
// 情况3: 比较版本号,如果 DB 中的版本更新,当前 account 已过时
if
latestVersion
>
currentVersion
{
slog
.
Debug
(
"token_version_stale"
,
"account_id"
,
account
.
ID
,
"current_version"
,
currentVersion
,
"latest_version"
,
latestVersion
)
return
latestAccount
,
true
}
return
latestAccount
,
false
}
backend/internal/service/token_cache_invalidator_test.go
View file @
a161fcc8
...
...
@@ -51,7 +51,27 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"gemini:project-x"
},
cache
.
deletedKeys
)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
// 这是为了处理:首次获取 token 时可能没有 project_id,之后自动检测到后会使用新 key
require
.
Equal
(
t
,
[]
string
{
"gemini:project-x"
,
"gemini:account:10"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_GeminiWithoutProjectID
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
10
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"gemini-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require
.
Equal
(
t
,
[]
string
{
"gemini:account:10"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
...
...
@@ -68,7 +88,26 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
},
cache
.
deletedKeys
)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
,
"ag:account:99"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_AntigravityWithoutProjectID
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
99
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"ag-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require
.
Equal
(
t
,
[]
string
{
"ag:account:99"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_OpenAI
(
t
*
testing
.
T
)
{
...
...
@@ -233,9 +272,10 @@ func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 新行为:删除失败只记录日志,不返回错误
// 这是因为缓存失效失败不应影响主业务流程
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
expectedErr
,
err
)
require
.
NoError
(
t
,
err
)
})
}
}
...
...
@@ -252,9 +292,12 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
},
}
// 新行为:Gemini 和 Antigravity 会同时删除基于 project_id 和 account_id 的键
expectedKeys
:=
[]
string
{
"gemini:gemini-proj"
,
"gemini:account:1"
,
"ag:ag-proj"
,
"ag:account:2"
,
"openai:account:3"
,
"claude:account:4"
,
}
...
...
@@ -266,3 +309,239 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
require
.
Equal
(
t
,
expectedKeys
,
cache
.
deletedKeys
)
}
// ========== GetCredentialAsInt64 测试 ==========
func
TestAccount_GetCredentialAsInt64
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
credentials
map
[
string
]
any
key
string
expected
int64
}{
{
name
:
"int64_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
1737654321000
)},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"float64_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
float64
(
1737654321000
)},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"int_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
12345
},
key
:
"_token_version"
,
expected
:
12345
,
},
{
name
:
"string_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
"1737654321000"
},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"string_with_spaces"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
" 1737654321000 "
},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"nil_credentials"
,
credentials
:
nil
,
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"missing_key"
,
credentials
:
map
[
string
]
any
{
"other_key"
:
123
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"nil_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
nil
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"invalid_string"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
"not_a_number"
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"empty_string"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
""
},
key
:
"_token_version"
,
expected
:
0
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Credentials
:
tt
.
credentials
}
result
:=
account
.
GetCredentialAsInt64
(
tt
.
key
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestAccount_GetCredentialAsInt64_NilAccount
(
t
*
testing
.
T
)
{
var
account
*
Account
result
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
require
.
Equal
(
t
,
int64
(
0
),
result
)
}
// ========== CheckTokenVersion 测试 ==========
func
TestCheckTokenVersion
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
latestAccount
*
Account
repoErr
error
expectedStale
bool
}{
{
name
:
"nil_account"
,
account
:
nil
,
latestAccount
:
nil
,
expectedStale
:
false
,
},
{
name
:
"no_version_in_account_but_db_has_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
true
,
// 当前 account 无版本但 DB 有,说明已被异步刷新,当前已过时
},
{
name
:
"both_no_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
expectedStale
:
false
,
// 两边都没有版本号,说明从未被异步刷新过,允许缓存
},
{
name
:
"same_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
false
,
},
{
name
:
"current_version_newer"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
200
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
false
,
},
{
name
:
"current_version_older_stale"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
200
)},
},
expectedStale
:
true
,
// 当前版本过时
},
{
name
:
"repo_error"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
nil
,
repoErr
:
errors
.
New
(
"db error"
),
expectedStale
:
false
,
// 查询失败,默认允许缓存
},
{
name
:
"repo_returns_nil"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
nil
,
repoErr
:
nil
,
expectedStale
:
false
,
// 查询返回 nil,默认允许缓存
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 由于 CheckTokenVersion 接受 AccountRepository 接口,而创建完整的 mock 很繁琐
// 这里我们直接测试函数的核心逻辑来验证行为
if
tt
.
name
==
"nil_account"
{
_
,
isStale
:=
CheckTokenVersion
(
context
.
Background
(),
nil
,
nil
)
require
.
Equal
(
t
,
tt
.
expectedStale
,
isStale
)
return
}
// 模拟 CheckTokenVersion 的核心逻辑
account
:=
tt
.
account
currentVersion
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
// 模拟 repo 查询
latestAccount
:=
tt
.
latestAccount
if
tt
.
repoErr
!=
nil
||
latestAccount
==
nil
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
false
)
return
}
latestVersion
:=
latestAccount
.
GetCredentialAsInt64
(
"_token_version"
)
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
if
currentVersion
==
0
&&
latestVersion
>
0
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
true
)
return
}
// 情况2: 两边都没有版本号
if
currentVersion
==
0
&&
latestVersion
==
0
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
false
)
return
}
// 情况3: 比较版本号
isStale
:=
latestVersion
>
currentVersion
require
.
Equal
(
t
,
tt
.
expectedStale
,
isStale
)
})
}
}
func
TestCheckTokenVersion_NilRepo
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
}
_
,
isStale
:=
CheckTokenVersion
(
context
.
Background
(),
account
,
nil
)
require
.
False
(
t
,
isStale
)
// nil repo,默认允许缓存
}
backend/internal/service/token_refresh_service.go
View file @
a161fcc8
...
...
@@ -169,6 +169,10 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 如果有新凭证,先更新(即使有错误也要保存 token)
if
newCredentials
!=
nil
{
// 记录刷新版本时间戳,用于解决缓存一致性问题
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
account
.
Credentials
=
newCredentials
if
saveErr
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
saveErr
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
...
...
@@ -233,7 +237,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效,需要用户重新授权
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
func
isNonRetryableRefreshError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
...
...
backend/internal/service/totp_service.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log/slog"
"time"
"github.com/pquerna/otp/totp"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var
(
ErrTotpNotEnabled
=
infraerrors
.
BadRequest
(
"TOTP_NOT_ENABLED"
,
"totp feature is not enabled"
)
ErrTotpAlreadyEnabled
=
infraerrors
.
BadRequest
(
"TOTP_ALREADY_ENABLED"
,
"totp is already enabled for this account"
)
ErrTotpNotSetup
=
infraerrors
.
BadRequest
(
"TOTP_NOT_SETUP"
,
"totp is not set up for this account"
)
ErrTotpInvalidCode
=
infraerrors
.
BadRequest
(
"TOTP_INVALID_CODE"
,
"invalid totp code"
)
ErrTotpSetupExpired
=
infraerrors
.
BadRequest
(
"TOTP_SETUP_EXPIRED"
,
"totp setup session expired"
)
ErrTotpTooManyAttempts
=
infraerrors
.
TooManyRequests
(
"TOTP_TOO_MANY_ATTEMPTS"
,
"too many verification attempts, please try again later"
)
ErrVerifyCodeRequired
=
infraerrors
.
BadRequest
(
"VERIFY_CODE_REQUIRED"
,
"email verification code is required"
)
ErrPasswordRequired
=
infraerrors
.
BadRequest
(
"PASSWORD_REQUIRED"
,
"password is required"
)
)
// TotpCache defines cache operations for TOTP service
type
TotpCache
interface
{
// Setup session methods
GetSetupSession
(
ctx
context
.
Context
,
userID
int64
)
(
*
TotpSetupSession
,
error
)
SetSetupSession
(
ctx
context
.
Context
,
userID
int64
,
session
*
TotpSetupSession
,
ttl
time
.
Duration
)
error
DeleteSetupSession
(
ctx
context
.
Context
,
userID
int64
)
error
// Login session methods (for 2FA login flow)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
TotpLoginSession
,
error
)
SetLoginSession
(
ctx
context
.
Context
,
tempToken
string
,
session
*
TotpLoginSession
,
ttl
time
.
Duration
)
error
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
// Rate limiting
IncrementVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
GetVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
ClearVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
error
}
// SecretEncryptor defines encryption operations for TOTP secrets
type
SecretEncryptor
interface
{
Encrypt
(
plaintext
string
)
(
string
,
error
)
Decrypt
(
ciphertext
string
)
(
string
,
error
)
}
// TotpSetupSession represents a TOTP setup session
type
TotpSetupSession
struct
{
Secret
string
// Plain text TOTP secret (not encrypted yet)
SetupToken
string
// Random token to verify setup request
CreatedAt
time
.
Time
}
// TotpLoginSession represents a pending 2FA login session
type
TotpLoginSession
struct
{
UserID
int64
Email
string
TokenExpiry
time
.
Time
}
// TotpStatus represents the TOTP status for a user
type
TotpStatus
struct
{
Enabled
bool
`json:"enabled"`
EnabledAt
*
time
.
Time
`json:"enabled_at,omitempty"`
FeatureEnabled
bool
`json:"feature_enabled"`
}
// TotpSetupResponse represents the response for initiating TOTP setup
type
TotpSetupResponse
struct
{
Secret
string
`json:"secret"`
QRCodeURL
string
`json:"qr_code_url"`
SetupToken
string
`json:"setup_token"`
Countdown
int
`json:"countdown"`
// seconds until setup expires
}
const
(
totpSetupTTL
=
5
*
time
.
Minute
totpLoginTTL
=
5
*
time
.
Minute
totpAttemptsTTL
=
15
*
time
.
Minute
maxTotpAttempts
=
5
totpIssuer
=
"Sub2API"
)
// TotpService handles TOTP operations
type
TotpService
struct
{
userRepo
UserRepository
encryptor
SecretEncryptor
cache
TotpCache
settingService
*
SettingService
emailService
*
EmailService
emailQueueService
*
EmailQueueService
}
// NewTotpService creates a new TOTP service
func
NewTotpService
(
userRepo
UserRepository
,
encryptor
SecretEncryptor
,
cache
TotpCache
,
settingService
*
SettingService
,
emailService
*
EmailService
,
emailQueueService
*
EmailQueueService
,
)
*
TotpService
{
return
&
TotpService
{
userRepo
:
userRepo
,
encryptor
:
encryptor
,
cache
:
cache
,
settingService
:
settingService
,
emailService
:
emailService
,
emailQueueService
:
emailQueueService
,
}
}
// GetStatus returns the TOTP status for a user
func
(
s
*
TotpService
)
GetStatus
(
ctx
context
.
Context
,
userID
int64
)
(
*
TotpStatus
,
error
)
{
featureEnabled
:=
s
.
settingService
.
IsTotpEnabled
(
ctx
)
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
return
&
TotpStatus
{
Enabled
:
user
.
TotpEnabled
,
EnabledAt
:
user
.
TotpEnabledAt
,
FeatureEnabled
:
featureEnabled
,
},
nil
}
// InitiateSetup starts the TOTP setup process
// If email verification is enabled, emailCode is required; otherwise password is required
func
(
s
*
TotpService
)
InitiateSetup
(
ctx
context
.
Context
,
userID
int64
,
emailCode
,
password
string
)
(
*
TotpSetupResponse
,
error
)
{
// Check if TOTP feature is enabled globally
if
!
s
.
settingService
.
IsTotpEnabled
(
ctx
)
{
return
nil
,
ErrTotpNotEnabled
}
// Get user and check if TOTP is already enabled
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
user
.
TotpEnabled
{
return
nil
,
ErrTotpAlreadyEnabled
}
// Verify identity based on email verification setting
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// Email verification enabled - verify email code
if
emailCode
==
""
{
return
nil
,
ErrVerifyCodeRequired
}
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
user
.
Email
,
emailCode
);
err
!=
nil
{
return
nil
,
err
}
}
else
{
// Email verification disabled - verify password
if
password
==
""
{
return
nil
,
ErrPasswordRequired
}
if
!
user
.
CheckPassword
(
password
)
{
return
nil
,
ErrPasswordIncorrect
}
}
// Generate a new TOTP key
key
,
err
:=
totp
.
Generate
(
totp
.
GenerateOpts
{
Issuer
:
totpIssuer
,
AccountName
:
user
.
Email
,
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate totp key: %w"
,
err
)
}
// Generate a random setup token
setupToken
,
err
:=
generateRandomToken
(
32
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate setup token: %w"
,
err
)
}
// Store the setup session in cache
session
:=
&
TotpSetupSession
{
Secret
:
key
.
Secret
(),
SetupToken
:
setupToken
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetSetupSession
(
ctx
,
userID
,
session
,
totpSetupTTL
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"store setup session: %w"
,
err
)
}
return
&
TotpSetupResponse
{
Secret
:
key
.
Secret
(),
QRCodeURL
:
key
.
URL
(),
SetupToken
:
setupToken
,
Countdown
:
int
(
totpSetupTTL
.
Seconds
()),
},
nil
}
// CompleteSetup completes the TOTP setup by verifying the code
func
(
s
*
TotpService
)
CompleteSetup
(
ctx
context
.
Context
,
userID
int64
,
totpCode
,
setupToken
string
)
error
{
// Check if TOTP feature is enabled globally
if
!
s
.
settingService
.
IsTotpEnabled
(
ctx
)
{
return
ErrTotpNotEnabled
}
// Get the setup session
session
,
err
:=
s
.
cache
.
GetSetupSession
(
ctx
,
userID
)
if
err
!=
nil
{
return
ErrTotpSetupExpired
}
if
session
==
nil
{
return
ErrTotpSetupExpired
}
// Verify the setup token (constant-time comparison)
if
subtle
.
ConstantTimeCompare
([]
byte
(
session
.
SetupToken
),
[]
byte
(
setupToken
))
!=
1
{
return
ErrTotpSetupExpired
}
// Verify the TOTP code
if
!
totp
.
Validate
(
totpCode
,
session
.
Secret
)
{
return
ErrTotpInvalidCode
}
setupSecretPrefix
:=
"N/A"
if
len
(
session
.
Secret
)
>=
4
{
setupSecretPrefix
=
session
.
Secret
[
:
4
]
}
slog
.
Debug
(
"totp_complete_setup_before_encrypt"
,
"user_id"
,
userID
,
"secret_len"
,
len
(
session
.
Secret
),
"secret_prefix"
,
setupSecretPrefix
)
// Encrypt the secret
encryptedSecret
,
err
:=
s
.
encryptor
.
Encrypt
(
session
.
Secret
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"encrypt totp secret: %w"
,
err
)
}
slog
.
Debug
(
"totp_complete_setup_encrypted"
,
"user_id"
,
userID
,
"encrypted_len"
,
len
(
encryptedSecret
))
// Verify encryption by decrypting
decrypted
,
decErr
:=
s
.
encryptor
.
Decrypt
(
encryptedSecret
)
if
decErr
!=
nil
{
slog
.
Debug
(
"totp_complete_setup_verify_failed"
,
"user_id"
,
userID
,
"error"
,
decErr
)
}
else
{
decryptedPrefix
:=
"N/A"
if
len
(
decrypted
)
>=
4
{
decryptedPrefix
=
decrypted
[
:
4
]
}
slog
.
Debug
(
"totp_complete_setup_verified"
,
"user_id"
,
userID
,
"original_len"
,
len
(
session
.
Secret
),
"decrypted_len"
,
len
(
decrypted
),
"match"
,
session
.
Secret
==
decrypted
,
"decrypted_prefix"
,
decryptedPrefix
)
}
// Update user with encrypted TOTP secret
if
err
:=
s
.
userRepo
.
UpdateTotpSecret
(
ctx
,
userID
,
&
encryptedSecret
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update totp secret: %w"
,
err
)
}
// Enable TOTP for the user
if
err
:=
s
.
userRepo
.
EnableTotp
(
ctx
,
userID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"enable totp: %w"
,
err
)
}
// Clean up the setup session
_
=
s
.
cache
.
DeleteSetupSession
(
ctx
,
userID
)
return
nil
}
// Disable disables TOTP for a user
// If email verification is enabled, emailCode is required; otherwise password is required
func
(
s
*
TotpService
)
Disable
(
ctx
context
.
Context
,
userID
int64
,
emailCode
,
password
string
)
error
{
// Get user
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
!
user
.
TotpEnabled
{
return
ErrTotpNotSetup
}
// Verify identity based on email verification setting
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// Email verification enabled - verify email code
if
emailCode
==
""
{
return
ErrVerifyCodeRequired
}
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
user
.
Email
,
emailCode
);
err
!=
nil
{
return
err
}
}
else
{
// Email verification disabled - verify password
if
password
==
""
{
return
ErrPasswordRequired
}
if
!
user
.
CheckPassword
(
password
)
{
return
ErrPasswordIncorrect
}
}
// Disable TOTP
if
err
:=
s
.
userRepo
.
DisableTotp
(
ctx
,
userID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"disable totp: %w"
,
err
)
}
return
nil
}
// VerifyCode verifies a TOTP code for a user
func
(
s
*
TotpService
)
VerifyCode
(
ctx
context
.
Context
,
userID
int64
,
code
string
)
error
{
slog
.
Debug
(
"totp_verify_code_called"
,
"user_id"
,
userID
,
"code_len"
,
len
(
code
))
// Check rate limiting
attempts
,
err
:=
s
.
cache
.
GetVerifyAttempts
(
ctx
,
userID
)
if
err
==
nil
&&
attempts
>=
maxTotpAttempts
{
return
ErrTotpTooManyAttempts
}
// Get user
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
slog
.
Debug
(
"totp_verify_get_user_failed"
,
"user_id"
,
userID
,
"error"
,
err
)
return
infraerrors
.
InternalServer
(
"TOTP_VERIFY_ERROR"
,
"failed to verify totp code"
)
}
if
!
user
.
TotpEnabled
||
user
.
TotpSecretEncrypted
==
nil
{
slog
.
Debug
(
"totp_verify_not_setup"
,
"user_id"
,
userID
,
"enabled"
,
user
.
TotpEnabled
,
"has_secret"
,
user
.
TotpSecretEncrypted
!=
nil
)
return
ErrTotpNotSetup
}
slog
.
Debug
(
"totp_verify_encrypted_secret"
,
"user_id"
,
userID
,
"encrypted_len"
,
len
(
*
user
.
TotpSecretEncrypted
))
// Decrypt the secret
secret
,
err
:=
s
.
encryptor
.
Decrypt
(
*
user
.
TotpSecretEncrypted
)
if
err
!=
nil
{
slog
.
Debug
(
"totp_verify_decrypt_failed"
,
"user_id"
,
userID
,
"error"
,
err
)
return
infraerrors
.
InternalServer
(
"TOTP_VERIFY_ERROR"
,
"failed to verify totp code"
)
}
secretPrefix
:=
"N/A"
if
len
(
secret
)
>=
4
{
secretPrefix
=
secret
[
:
4
]
}
slog
.
Debug
(
"totp_verify_decrypted"
,
"user_id"
,
userID
,
"secret_len"
,
len
(
secret
),
"secret_prefix"
,
secretPrefix
)
// Verify the code
valid
:=
totp
.
Validate
(
code
,
secret
)
slog
.
Debug
(
"totp_verify_result"
,
"user_id"
,
userID
,
"valid"
,
valid
,
"secret_len"
,
len
(
secret
),
"secret_prefix"
,
secretPrefix
,
"server_time"
,
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
))
if
!
valid
{
// Increment failed attempts
_
,
_
=
s
.
cache
.
IncrementVerifyAttempts
(
ctx
,
userID
)
return
ErrTotpInvalidCode
}
// Clear attempt counter on success
_
=
s
.
cache
.
ClearVerifyAttempts
(
ctx
,
userID
)
return
nil
}
// CreateLoginSession creates a temporary login session for 2FA
func
(
s
*
TotpService
)
CreateLoginSession
(
ctx
context
.
Context
,
userID
int64
,
email
string
)
(
string
,
error
)
{
// Generate a random temp token
tempToken
,
err
:=
generateRandomToken
(
32
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate temp token: %w"
,
err
)
}
session
:=
&
TotpLoginSession
{
UserID
:
userID
,
Email
:
email
,
TokenExpiry
:
time
.
Now
()
.
Add
(
totpLoginTTL
),
}
if
err
:=
s
.
cache
.
SetLoginSession
(
ctx
,
tempToken
,
session
,
totpLoginTTL
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"store login session: %w"
,
err
)
}
return
tempToken
,
nil
}
// GetLoginSession retrieves a login session
func
(
s
*
TotpService
)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
TotpLoginSession
,
error
)
{
return
s
.
cache
.
GetLoginSession
(
ctx
,
tempToken
)
}
// DeleteLoginSession deletes a login session
func
(
s
*
TotpService
)
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
{
return
s
.
cache
.
DeleteLoginSession
(
ctx
,
tempToken
)
}
// IsTotpEnabledForUser checks if TOTP is enabled for a specific user
func
(
s
*
TotpService
)
IsTotpEnabledForUser
(
ctx
context
.
Context
,
userID
int64
)
(
bool
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
false
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
return
user
.
TotpEnabled
,
nil
}
// MaskEmail masks an email address for display
func
MaskEmail
(
email
string
)
string
{
if
len
(
email
)
<
3
{
return
"***"
}
atIdx
:=
-
1
for
i
,
c
:=
range
email
{
if
c
==
'@'
{
atIdx
=
i
break
}
}
if
atIdx
==
-
1
||
atIdx
<
1
{
return
email
[
:
1
]
+
"***"
}
localPart
:=
email
[
:
atIdx
]
domain
:=
email
[
atIdx
:
]
if
len
(
localPart
)
<=
2
{
return
localPart
[
:
1
]
+
"***"
+
domain
}
return
localPart
[
:
1
]
+
"***"
+
localPart
[
len
(
localPart
)
-
1
:
]
+
domain
}
// generateRandomToken generates a random hex-encoded token
func
generateRandomToken
(
byteLength
int
)
(
string
,
error
)
{
b
:=
make
([]
byte
,
byteLength
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
b
),
nil
}
// VerificationMethod represents the method required for TOTP operations
type
VerificationMethod
struct
{
Method
string
`json:"method"`
// "email" or "password"
}
// GetVerificationMethod returns the verification method for TOTP operations
func
(
s
*
TotpService
)
GetVerificationMethod
(
ctx
context
.
Context
)
*
VerificationMethod
{
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
&
VerificationMethod
{
Method
:
"email"
}
}
return
&
VerificationMethod
{
Method
:
"password"
}
}
// SendVerifyCode sends an email verification code for TOTP operations
func
(
s
*
TotpService
)
SendVerifyCode
(
ctx
context
.
Context
,
userID
int64
)
error
{
// Check if email verification is enabled
if
!
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_NOT_ENABLED"
,
"email verification is not enabled"
)
}
// Get user email
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// Get site name for email
siteName
:=
s
.
settingService
.
GetSiteName
(
ctx
)
// Send verification code via queue
return
s
.
emailQueueService
.
EnqueueVerifyCode
(
user
.
Email
,
siteName
)
}
backend/internal/service/user.go
View file @
a161fcc8
...
...
@@ -21,6 +21,11 @@ type User struct {
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
// TOTP 双因素认证字段
TotpSecretEncrypted
*
string
// AES-256-GCM 加密的 TOTP 密钥
TotpEnabled
bool
// 是否启用 TOTP
TotpEnabledAt
*
time
.
Time
// TOTP 启用时间
APIKeys
[]
APIKey
Subscriptions
[]
UserSubscription
}
...
...
backend/internal/service/user_service.go
View file @
a161fcc8
...
...
@@ -38,6 +38,11 @@ type UserRepository interface {
UpdateConcurrency
(
ctx
context
.
Context
,
id
int64
,
amount
int
)
error
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
RemoveGroupFromAllowedGroups
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
// TOTP 相关方法
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
}
// UpdateProfileRequest 更新用户资料请求
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment