Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a161fcc8
"...internal/pkg/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "9634494ba943a4f488a2093b5024631060c7dd5f"
Commit
a161fcc8
authored
Jan 26, 2026
by
cyhhao
Browse files
Merge branch 'main' of github.com:Wei-Shaw/sub2api
parents
65e69738
e32c5f53
Changes
119
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/email_service.go
View file @
a161fcc8
...
@@ -3,11 +3,14 @@ package service
...
@@ -3,11 +3,14 @@ package service
import
(
import
(
"context"
"context"
"crypto/rand"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"crypto/tls"
"encoding/hex"
"fmt"
"fmt"
"log"
"log"
"math/big"
"math/big"
"net/smtp"
"net/smtp"
"net/url"
"strconv"
"strconv"
"time"
"time"
...
@@ -19,6 +22,9 @@ var (
...
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
// Password reset errors
ErrInvalidResetToken
=
infraerrors
.
BadRequest
(
"INVALID_RESET_TOKEN"
,
"invalid or expired password reset token"
)
)
)
// EmailCache defines cache operations for email service
// EmailCache defines cache operations for email service
...
@@ -26,6 +32,16 @@ type EmailCache interface {
...
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
}
}
// VerificationCodeData represents verification code data
// VerificationCodeData represents verification code data
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt
time
.
Time
CreatedAt
time
.
Time
}
}
// PasswordResetTokenData represents password reset token data
type
PasswordResetTokenData
struct
{
Token
string
CreatedAt
time
.
Time
}
const
(
const
(
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
maxVerifyCodeAttempts
=
5
maxVerifyCodeAttempts
=
5
// Password reset token settings
passwordResetTokenTTL
=
30
*
time
.
Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown
=
30
*
time
.
Second
)
)
// SMTPConfig SMTP配置
// SMTPConfig SMTP配置
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return
ErrVerifyCodeMaxAttempts
return
ErrVerifyCodeMaxAttempts
}
}
// 验证码不匹配
// 验证码不匹配
(constant-time comparison to prevent timing attacks)
if
data
.
Code
!=
code
{
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return
client
.
Quit
()
return
client
.
Quit
()
}
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func
(
s
*
EmailService
)
GeneratePasswordResetToken
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func
(
s
*
EmailService
)
SendPasswordResetEmail
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
var
token
string
var
needSaveToken
bool
// Check if token already exists
existing
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
==
nil
&&
existing
!=
nil
{
// Token exists, reuse it (allows resending email without generating new token)
token
=
existing
.
Token
needSaveToken
=
false
}
else
{
// Generate new token
token
,
err
=
s
.
GeneratePasswordResetToken
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
needSaveToken
=
true
}
// Save token to Redis (only if new token generated)
if
needSaveToken
{
data
:=
&
PasswordResetTokenData
{
Token
:
token
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetPasswordResetToken
(
ctx
,
email
,
data
,
passwordResetTokenTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save reset token: %w"
,
err
)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL
:=
fmt
.
Sprintf
(
"%s?email=%s&token=%s"
,
resetURL
,
url
.
QueryEscape
(
email
),
url
.
QueryEscape
(
token
))
// Build email content
subject
:=
fmt
.
Sprintf
(
"[%s] 密码重置请求"
,
siteName
)
body
:=
s
.
buildPasswordResetEmailBody
(
fullResetURL
,
siteName
)
// Send email
if
err
:=
s
.
SendEmail
(
ctx
,
email
,
subject
,
body
);
err
!=
nil
{
return
fmt
.
Errorf
(
"send email: %w"
,
err
)
}
return
nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] Password reset email skipped (cooldown): %s"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if
err
:=
s
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
return
err
}
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to set password reset cooldown for %s: %v"
,
email
,
err
)
}
return
nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func
(
s
*
EmailService
)
VerifyPasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
data
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Token
),
[]
byte
(
token
))
!=
1
{
return
ErrInvalidResetToken
}
return
nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func
(
s
*
EmailService
)
ConsumePasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
// Verify first
if
err
:=
s
.
VerifyPasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete password reset token after consumption: %v"
,
err
)
}
return
nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func
(
s
*
EmailService
)
buildPasswordResetEmailBody
(
resetURL
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`
,
siteName
,
resetURL
,
resetURL
)
}
backend/internal/service/gateway_service.go
View file @
a161fcc8
...
@@ -342,6 +342,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
...
@@ -342,6 +342,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
accountID
,
stickySessionTTL
)
return
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
accountID
,
stickySessionTTL
)
}
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func
(
s
*
GatewayService
)
GetCachedSessionAccountID
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
int64
,
error
)
{
if
sessionHash
==
""
||
s
.
cache
==
nil
{
return
0
,
nil
}
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
!=
nil
{
return
0
,
err
}
return
accountID
,
nil
}
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
if
parsed
==
nil
{
if
parsed
==
nil
{
return
""
return
""
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
a161fcc8
...
@@ -1972,6 +1972,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
...
@@ -1972,6 +1972,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
var
last
map
[
string
]
any
var
last
map
[
string
]
any
var
lastWithParts
map
[
string
]
any
var
lastWithParts
map
[
string
]
any
var
collectedTextParts
[]
string
// Collect all text parts for aggregation
usage
:=
&
ClaudeUsage
{}
usage
:=
&
ClaudeUsage
{}
for
{
for
{
...
@@ -1983,7 +1984,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
...
@@ -1983,7 +1984,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
switch
payload
{
switch
payload
{
case
""
,
"[DONE]"
:
case
""
,
"[DONE]"
:
if
payload
==
"[DONE]"
{
if
payload
==
"[DONE]"
{
return
pickGeminiCollectResult
(
last
,
lastWithParts
),
usage
,
nil
return
mergeCollectedTextParts
(
pickGeminiCollectResult
(
last
,
lastWithParts
),
collectedTextParts
),
usage
,
nil
}
}
default
:
default
:
var
parsed
map
[
string
]
any
var
parsed
map
[
string
]
any
...
@@ -2002,6 +2003,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
...
@@ -2002,6 +2003,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
}
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
lastWithParts
=
parsed
lastWithParts
=
parsed
// Collect text from each part for aggregation
for
_
,
part
:=
range
parts
{
if
text
,
ok
:=
part
[
"text"
]
.
(
string
);
ok
&&
text
!=
""
{
collectedTextParts
=
append
(
collectedTextParts
,
text
)
}
}
}
}
}
}
}
}
...
@@ -2016,7 +2023,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
...
@@ -2016,7 +2023,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
}
}
}
return
pickGeminiCollectResult
(
last
,
lastWithParts
),
usage
,
nil
return
mergeCollectedTextParts
(
pickGeminiCollectResult
(
last
,
lastWithParts
),
collectedTextParts
),
usage
,
nil
}
}
func
pickGeminiCollectResult
(
last
map
[
string
]
any
,
lastWithParts
map
[
string
]
any
)
map
[
string
]
any
{
func
pickGeminiCollectResult
(
last
map
[
string
]
any
,
lastWithParts
map
[
string
]
any
)
map
[
string
]
any
{
...
@@ -2029,6 +2036,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
...
@@ -2029,6 +2036,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
return
map
[
string
]
any
{}
return
map
[
string
]
any
{}
}
}
// mergeCollectedTextParts merges all collected text chunks into the final response.
// This fixes the issue where non-streaming responses only returned the last chunk
// instead of the complete aggregated text.
func
mergeCollectedTextParts
(
response
map
[
string
]
any
,
textParts
[]
string
)
map
[
string
]
any
{
if
len
(
textParts
)
==
0
{
return
response
}
// Join all text parts
mergedText
:=
strings
.
Join
(
textParts
,
""
)
// Deep copy response
result
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
response
{
result
[
k
]
=
v
}
// Get or create candidates
candidates
,
ok
:=
result
[
"candidates"
]
.
([]
any
)
if
!
ok
||
len
(
candidates
)
==
0
{
candidates
=
[]
any
{
map
[
string
]
any
{}}
}
// Get first candidate
candidate
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
candidate
=
make
(
map
[
string
]
any
)
candidates
[
0
]
=
candidate
}
// Get or create content
content
,
ok
:=
candidate
[
"content"
]
.
(
map
[
string
]
any
)
if
!
ok
{
content
=
map
[
string
]
any
{
"role"
:
"model"
}
candidate
[
"content"
]
=
content
}
// Get existing parts
existingParts
,
ok
:=
content
[
"parts"
]
.
([]
any
)
if
!
ok
{
existingParts
=
[]
any
{}
}
// Find and update first text part, or create new one
newParts
:=
make
([]
any
,
0
,
len
(
existingParts
)
+
1
)
textUpdated
:=
false
for
_
,
p
:=
range
existingParts
{
pm
,
ok
:=
p
.
(
map
[
string
]
any
)
if
!
ok
{
newParts
=
append
(
newParts
,
p
)
continue
}
if
_
,
hasText
:=
pm
[
"text"
];
hasText
&&
!
textUpdated
{
// Replace with merged text
newPart
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
pm
{
newPart
[
k
]
=
v
}
newPart
[
"text"
]
=
mergedText
newParts
=
append
(
newParts
,
newPart
)
textUpdated
=
true
}
else
{
newParts
=
append
(
newParts
,
pm
)
}
}
if
!
textUpdated
{
newParts
=
append
([]
any
{
map
[
string
]
any
{
"text"
:
mergedText
}},
newParts
...
)
}
content
[
"parts"
]
=
newParts
result
[
"candidates"
]
=
candidates
return
result
}
type
geminiNativeStreamResult
struct
{
type
geminiNativeStreamResult
struct
{
usage
*
ClaudeUsage
usage
*
ClaudeUsage
firstTokenMs
*
int
firstTokenMs
*
int
...
...
backend/internal/service/gemini_native_signature_cleaner.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
// 解析 JSON
var
data
any
if
err
:=
json
.
Unmarshal
(
body
,
&
data
);
err
!=
nil
{
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
return
body
}
// 递归清理 thoughtSignature
cleaned
:=
cleanThoughtSignaturesRecursive
(
data
)
// 重新序列化
result
,
err
:=
json
.
Marshal
(
cleaned
)
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
return
body
}
return
result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func
cleanThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
// 创建新的 map,移除 thoughtSignature
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
// 跳过 thoughtSignature 字段
if
key
==
"thoughtSignature"
{
continue
}
// 递归处理嵌套结构
result
[
key
]
=
cleanThoughtSignaturesRecursive
(
value
)
}
return
result
case
[]
any
:
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
result
[
i
]
=
cleanThoughtSignaturesRecursive
(
item
)
}
return
result
default
:
// 基本类型(string, number, bool, null)直接返回
return
v
}
}
backend/internal/service/gemini_token_provider.go
View file @
a161fcc8
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"context"
"context"
"errors"
"errors"
"log"
"log"
"log/slog"
"strconv"
"strconv"
"strings"
"strings"
"time"
"time"
...
@@ -131,21 +132,32 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -131,21 +132,32 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
}
}
// 3) Populate cache with TTL
.
// 3) Populate cache with TTL
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
expiresAt
!=
nil
{
if
isStale
&&
latestAccount
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
// 版本过时,使用 DB 中的最新 token
switch
{
slog
.
Debug
(
"gemini_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
case
until
>
geminiTokenCacheSkew
:
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
ttl
=
until
-
geminiTokenCacheSkew
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
case
until
>
0
:
return
""
,
errors
.
New
(
"access_token not found after version check"
)
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
geminiTokenCacheSkew
:
ttl
=
until
-
geminiTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
}
return
accessToken
,
nil
return
accessToken
,
nil
...
...
backend/internal/service/oauth_service.go
View file @
a161fcc8
...
@@ -122,6 +122,7 @@ type TokenInfo struct {
...
@@ -122,6 +122,7 @@ type TokenInfo struct {
Scope
string
`json:"scope,omitempty"`
Scope
string
`json:"scope,omitempty"`
OrgUUID
string
`json:"org_uuid,omitempty"`
OrgUUID
string
`json:"org_uuid,omitempty"`
AccountUUID
string
`json:"account_uuid,omitempty"`
AccountUUID
string
`json:"account_uuid,omitempty"`
EmailAddress
string
`json:"email_address,omitempty"`
}
}
// ExchangeCode exchanges authorization code for tokens
// ExchangeCode exchanges authorization code for tokens
...
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
...
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
tokenInfo
.
OrgUUID
=
tokenResp
.
Organization
.
UUID
tokenInfo
.
OrgUUID
=
tokenResp
.
Organization
.
UUID
log
.
Printf
(
"[OAuth] Got org_uuid: %s"
,
tokenInfo
.
OrgUUID
)
log
.
Printf
(
"[OAuth] Got org_uuid: %s"
,
tokenInfo
.
OrgUUID
)
}
}
if
tokenResp
.
Account
!=
nil
&&
tokenResp
.
Account
.
UUID
!=
""
{
if
tokenResp
.
Account
!=
nil
{
tokenInfo
.
AccountUUID
=
tokenResp
.
Account
.
UUID
if
tokenResp
.
Account
.
UUID
!=
""
{
log
.
Printf
(
"[OAuth] Got account_uuid: %s"
,
tokenInfo
.
AccountUUID
)
tokenInfo
.
AccountUUID
=
tokenResp
.
Account
.
UUID
log
.
Printf
(
"[OAuth] Got account_uuid: %s"
,
tokenInfo
.
AccountUUID
)
}
if
tokenResp
.
Account
.
EmailAddress
!=
""
{
tokenInfo
.
EmailAddress
=
tokenResp
.
Account
.
EmailAddress
log
.
Printf
(
"[OAuth] Got email_address: %s"
,
tokenInfo
.
EmailAddress
)
}
}
}
return
tokenInfo
,
nil
return
tokenInfo
,
nil
...
...
backend/internal/service/openai_gateway_service.go
View file @
a161fcc8
...
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
...
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
UpdatedAt
string
`json:"updated_at,omitempty"`
UpdatedAt
string
`json:"updated_at,omitempty"`
}
}
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
type
NormalizedCodexLimits
struct
{
Used5hPercent
*
float64
Reset5hSeconds
*
int
Window5hMinutes
*
int
Used7dPercent
*
float64
Reset7dSeconds
*
int
Window7dMinutes
*
int
}
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
// Returns nil if snapshot is nil or has no useful data.
func
(
s
*
OpenAICodexUsageSnapshot
)
Normalize
()
*
NormalizedCodexLimits
{
if
s
==
nil
{
return
nil
}
result
:=
&
NormalizedCodexLimits
{}
primaryMins
:=
0
secondaryMins
:=
0
hasPrimaryWindow
:=
false
hasSecondaryWindow
:=
false
if
s
.
PrimaryWindowMinutes
!=
nil
{
primaryMins
=
*
s
.
PrimaryWindowMinutes
hasPrimaryWindow
=
true
}
if
s
.
SecondaryWindowMinutes
!=
nil
{
secondaryMins
=
*
s
.
SecondaryWindowMinutes
hasSecondaryWindow
=
true
}
// Determine mapping based on window_minutes
use5hFromPrimary
:=
false
use7dFromPrimary
:=
false
if
hasPrimaryWindow
&&
hasSecondaryWindow
{
// Both known: smaller window is 5h, larger is 7d
if
primaryMins
<
secondaryMins
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasPrimaryWindow
{
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
if
primaryMins
<=
360
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasSecondaryWindow
{
// Only secondary known: classify by threshold
if
secondaryMins
<=
360
{
// 5h from secondary, so primary (if any data) is 7d
use7dFromPrimary
=
true
}
else
{
// 7d from secondary, so primary (if any data) is 5h
use5hFromPrimary
=
true
}
}
else
{
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
use7dFromPrimary
=
true
}
// Assign values
if
use5hFromPrimary
{
result
.
Used5hPercent
=
s
.
PrimaryUsedPercent
result
.
Reset5hSeconds
=
s
.
PrimaryResetAfterSeconds
result
.
Window5hMinutes
=
s
.
PrimaryWindowMinutes
result
.
Used7dPercent
=
s
.
SecondaryUsedPercent
result
.
Reset7dSeconds
=
s
.
SecondaryResetAfterSeconds
result
.
Window7dMinutes
=
s
.
SecondaryWindowMinutes
}
else
if
use7dFromPrimary
{
result
.
Used7dPercent
=
s
.
PrimaryUsedPercent
result
.
Reset7dSeconds
=
s
.
PrimaryResetAfterSeconds
result
.
Window7dMinutes
=
s
.
PrimaryWindowMinutes
result
.
Used5hPercent
=
s
.
SecondaryUsedPercent
result
.
Reset5hSeconds
=
s
.
SecondaryResetAfterSeconds
result
.
Window5hMinutes
=
s
.
SecondaryWindowMinutes
}
return
result
}
// OpenAIUsage represents OpenAI API response usage
// OpenAIUsage represents OpenAI API response usage
type
OpenAIUsage
struct
{
type
OpenAIUsage
struct
{
InputTokens
int
`json:"input_tokens"`
InputTokens
int
`json:"input_tokens"`
...
@@ -867,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
...
@@ -867,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if
account
.
Type
==
AccountTypeOAuth
{
if
account
.
Type
==
AccountTypeOAuth
{
if
snapshot
:=
extractCodexUsage
Headers
(
resp
.
Header
);
snapshot
!=
nil
{
if
snapshot
:=
ParseCodexRateLimit
Headers
(
resp
.
Header
);
snapshot
!=
nil
{
s
.
updateCodexUsageSnapshot
(
ctx
,
account
.
ID
,
snapshot
)
s
.
updateCodexUsageSnapshot
(
ctx
,
account
.
ID
,
snapshot
)
}
}
}
}
...
@@ -1706,8 +1792,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -1706,8 +1792,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return
nil
return
nil
}
}
// extractCodexUsageHeaders extracts Codex usage limits from response headers
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
func
extractCodexUsageHeaders
(
headers
http
.
Header
)
*
OpenAICodexUsageSnapshot
{
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func
ParseCodexRateLimitHeaders
(
headers
http
.
Header
)
*
OpenAICodexUsageSnapshot
{
snapshot
:=
&
OpenAICodexUsageSnapshot
{}
snapshot
:=
&
OpenAICodexUsageSnapshot
{}
hasData
:=
false
hasData
:=
false
...
@@ -1781,6 +1868,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
...
@@ -1781,6 +1868,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
// Convert snapshot to map for merging into Extra
// Convert snapshot to map for merging into Extra
updates
:=
make
(
map
[
string
]
any
)
updates
:=
make
(
map
[
string
]
any
)
// Save raw primary/secondary fields for debugging/tracing
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_primary_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
updates
[
"codex_primary_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
}
...
@@ -1804,109 +1893,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
...
@@ -1804,109 +1893,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
}
}
updates
[
"codex_usage_updated_at"
]
=
snapshot
.
UpdatedAt
updates
[
"codex_usage_updated_at"
]
=
snapshot
.
UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes
// Normalize to canonical 5h/7d fields
// This fixes the issue where OpenAI's primary/secondary naming is reversed
if
normalized
:=
snapshot
.
Normalize
();
normalized
!=
nil
{
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
if
normalized
.
Used5hPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
normalized
.
Used5hPercent
// IMPORTANT: We can only reliably determine window type from window_minutes field
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
var
primaryWindowMins
,
secondaryWindowMins
int
var
hasPrimaryWindow
,
hasSecondaryWindow
bool
// Only use window_minutes for reliable window size comparison
if
snapshot
.
PrimaryWindowMinutes
!=
nil
{
primaryWindowMins
=
*
snapshot
.
PrimaryWindowMinutes
hasPrimaryWindow
=
true
}
if
snapshot
.
SecondaryWindowMinutes
!=
nil
{
secondaryWindowMins
=
*
snapshot
.
SecondaryWindowMinutes
hasSecondaryWindow
=
true
}
// Determine which is 5h and which is 7d
var
use5hFromPrimary
,
use7dFromPrimary
bool
var
use5hFromSecondary
,
use7dFromSecondary
bool
if
hasPrimaryWindow
&&
hasSecondaryWindow
{
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
if
primaryWindowMins
<
secondaryWindowMins
{
use5hFromPrimary
=
true
use7dFromSecondary
=
true
}
else
{
use5hFromSecondary
=
true
use7dFromPrimary
=
true
}
}
else
if
hasPrimaryWindow
{
// Only primary window size known: classify by absolute threshold
if
primaryWindowMins
<=
360
{
use5hFromPrimary
=
true
}
else
{
use7dFromPrimary
=
true
}
}
else
if
hasSecondaryWindow
{
// Only secondary window size known: classify by absolute threshold
if
secondaryWindowMins
<=
360
{
use5hFromSecondary
=
true
}
else
{
use7dFromSecondary
=
true
}
}
else
{
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if
snapshot
.
SecondaryUsedPercent
!=
nil
||
snapshot
.
SecondaryResetAfterSeconds
!=
nil
||
snapshot
.
SecondaryWindowMinutes
!=
nil
{
use5hFromSecondary
=
true
}
if
snapshot
.
PrimaryUsedPercent
!=
nil
||
snapshot
.
PrimaryResetAfterSeconds
!=
nil
||
snapshot
.
PrimaryWindowMinutes
!=
nil
{
use7dFromPrimary
=
true
}
}
// Write canonical 5h fields
if
use5hFromPrimary
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
if
snapshot
.
PrimaryResetAfterSeconds
!=
nil
{
updates
[
"codex_5h_reset_after_seconds"
]
=
*
snapshot
.
PrimaryResetAfterSeconds
}
if
snapshot
.
PrimaryWindowMinutes
!=
nil
{
updates
[
"codex_5h_window_minutes"
]
=
*
snapshot
.
PrimaryWindowMinutes
}
}
else
if
use5hFromSecondary
{
if
snapshot
.
SecondaryUsedPercent
!=
nil
{
updates
[
"codex_5h_used_percent"
]
=
*
snapshot
.
SecondaryUsedPercent
}
if
snapshot
.
SecondaryResetAfterSeconds
!=
nil
{
updates
[
"codex_5h_reset_after_seconds"
]
=
*
snapshot
.
SecondaryResetAfterSeconds
}
if
snapshot
.
SecondaryWindowMinutes
!=
nil
{
updates
[
"codex_5h_window_minutes"
]
=
*
snapshot
.
SecondaryWindowMinutes
}
}
// Write canonical 7d fields
if
use7dFromPrimary
{
if
snapshot
.
PrimaryUsedPercent
!=
nil
{
updates
[
"codex_7d_used_percent"
]
=
*
snapshot
.
PrimaryUsedPercent
}
}
if
snapshot
.
PrimaryResetAfter
Seconds
!=
nil
{
if
normalized
.
Reset5h
Seconds
!=
nil
{
updates
[
"codex_
7d
_reset_after_seconds"
]
=
*
snapshot
.
PrimaryResetAfter
Seconds
updates
[
"codex_
5h
_reset_after_seconds"
]
=
*
normalized
.
Reset5h
Seconds
}
}
if
snapshot
.
Primary
WindowMinutes
!=
nil
{
if
normalized
.
Window
5h
Minutes
!=
nil
{
updates
[
"codex_
7d
_window_minutes"
]
=
*
snapshot
.
Primary
WindowMinutes
updates
[
"codex_
5h
_window_minutes"
]
=
*
normalized
.
Window
5h
Minutes
}
}
}
else
if
use7dFromSecondary
{
if
normalized
.
Used7dPercent
!=
nil
{
if
snapshot
.
SecondaryUsedPercent
!=
nil
{
updates
[
"codex_7d_used_percent"
]
=
*
normalized
.
Used7dPercent
updates
[
"codex_7d_used_percent"
]
=
*
snapshot
.
SecondaryUsedPercent
}
}
if
snapshot
.
SecondaryResetAfter
Seconds
!=
nil
{
if
normalized
.
Reset7d
Seconds
!=
nil
{
updates
[
"codex_7d_reset_after_seconds"
]
=
*
snapshot
.
SecondaryResetAfter
Seconds
updates
[
"codex_7d_reset_after_seconds"
]
=
*
normalized
.
Reset7d
Seconds
}
}
if
snapshot
.
Secondary
WindowMinutes
!=
nil
{
if
normalized
.
Window
7d
Minutes
!=
nil
{
updates
[
"codex_7d_window_minutes"
]
=
*
snapshot
.
Secondary
WindowMinutes
updates
[
"codex_7d_window_minutes"
]
=
*
normalized
.
Window
7d
Minutes
}
}
}
}
...
...
backend/internal/service/openai_token_provider.go
View file @
a161fcc8
...
@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
}
// 3. 存入缓存
// 3. 存入缓存
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
refreshFailed
{
if
isStale
&&
latestAccount
!=
nil
{
//
刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
//
版本过时,使用 DB 中的最新 token
ttl
=
time
.
Minute
slog
.
Debug
(
"openai_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
accessToken
=
latestAccount
.
GetOpenAIAccessToken
(
)
}
else
if
expiresAt
!=
nil
{
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
until
:=
time
.
Until
(
*
expiresAt
)
return
""
,
errors
.
New
(
"access_token not found after version check"
)
switch
{
}
case
until
>
openAITokenCacheSkew
:
// 不写入缓存,让下次请求重新处理
ttl
=
until
-
openAITokenCacheSkew
}
else
{
case
until
>
0
:
ttl
:=
30
*
time
.
Minute
ttl
=
until
if
refreshFailed
{
default
:
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
ttl
=
time
.
Minute
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
openAITokenCacheSkew
:
ttl
=
until
-
openAITokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
...
...
backend/internal/service/ratelimit_service.go
View file @
a161fcc8
...
@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
...
@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
// handle429 处理429限流错误
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
// 解析响应头获取重置时间,标记账号为限流状态
func
(
s
*
RateLimitService
)
handle429
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
,
responseBody
[]
byte
)
{
func
(
s
*
RateLimitService
)
handle429
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
,
responseBody
[]
byte
)
{
// 解析重置时间戳
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
if
account
.
Platform
==
PlatformOpenAI
{
if
resetAt
:=
s
.
calculateOpenAI429ResetTime
(
headers
);
resetAt
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
*
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"openai_account_rate_limited"
,
"account_id"
,
account
.
ID
,
"reset_at"
,
*
resetAt
)
return
}
}
// 2. 尝试从响应头解析重置时间(Anthropic)
resetTimestamp
:=
headers
.
Get
(
"anthropic-ratelimit-unified-reset"
)
resetTimestamp
:=
headers
.
Get
(
"anthropic-ratelimit-unified-reset"
)
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
if
resetTimestamp
==
""
{
if
resetTimestamp
==
""
{
switch
account
.
Platform
{
case
PlatformOpenAI
:
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if
resetAt
:=
parseOpenAIRateLimitResetTime
(
responseBody
);
resetAt
!=
nil
{
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"account_rate_limited"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"reset_at"
,
resetTime
,
"reset_in"
,
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
return
}
case
PlatformGemini
,
PlatformAntigravity
:
// 尝试解析 Gemini 格式(用于其他平台)
if
resetAt
:=
ParseGeminiRateLimitResetTime
(
responseBody
);
resetAt
!=
nil
{
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
}
slog
.
Info
(
"account_rate_limited"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"reset_at"
,
resetTime
,
"reset_in"
,
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
return
}
}
// 没有重置时间,使用默认5分钟
// 没有重置时间,使用默认5分钟
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
...
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
}
return
return
}
}
slog
.
Warn
(
"rate_limit_no_reset_time"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"using_default"
,
"5m"
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
slog
.
Warn
(
"rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
...
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
...
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
return
strings
.
Contains
(
msg
,
"sonnet"
)
return
strings
.
Contains
(
msg
,
"sonnet"
)
}
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func
(
s
*
RateLimitService
)
calculateOpenAI429ResetTime
(
headers
http
.
Header
)
*
time
.
Time
{
snapshot
:=
ParseCodexRateLimitHeaders
(
headers
)
if
snapshot
==
nil
{
return
nil
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
return
nil
}
now
:=
time
.
Now
()
// 判断哪个限制被触发(used_percent >= 100)
is7dExhausted
:=
normalized
.
Used7dPercent
!=
nil
&&
*
normalized
.
Used7dPercent
>=
100
is5hExhausted
:=
normalized
.
Used5hPercent
!=
nil
&&
*
normalized
.
Used5hPercent
>=
100
// 优先使用被触发限制的重置时间
if
is7dExhausted
&&
normalized
.
Reset7dSeconds
!=
nil
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
*
normalized
.
Reset7dSeconds
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_7d_limit_exhausted"
,
"reset_after_seconds"
,
*
normalized
.
Reset7dSeconds
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
if
is5hExhausted
&&
normalized
.
Reset5hSeconds
!=
nil
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
*
normalized
.
Reset5hSeconds
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_5h_limit_exhausted"
,
"reset_after_seconds"
,
*
normalized
.
Reset5hSeconds
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
// 都未达到100%但收到429,使用较长的重置时间
var
maxResetSecs
int
if
normalized
.
Reset7dSeconds
!=
nil
&&
*
normalized
.
Reset7dSeconds
>
maxResetSecs
{
maxResetSecs
=
*
normalized
.
Reset7dSeconds
}
if
normalized
.
Reset5hSeconds
!=
nil
&&
*
normalized
.
Reset5hSeconds
>
maxResetSecs
{
maxResetSecs
=
*
normalized
.
Reset5hSeconds
}
if
maxResetSecs
>
0
{
resetAt
:=
now
.
Add
(
time
.
Duration
(
maxResetSecs
)
*
time
.
Second
)
slog
.
Info
(
"openai_429_using_max_reset"
,
"max_reset_seconds"
,
maxResetSecs
,
"reset_at"
,
resetAt
)
return
&
resetAt
}
return
nil
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
// {
// "error": {
// "message": "The usage limit has been reached",
// "type": "usage_limit_reached",
// "resets_at": 1769404154,
// "resets_in_seconds": 133107
// }
// }
func
parseOpenAIRateLimitResetTime
(
body
[]
byte
)
*
int64
{
var
parsed
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
parsed
);
err
!=
nil
{
return
nil
}
errObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
errType
,
_
:=
errObj
[
"type"
]
.
(
string
)
if
errType
!=
"usage_limit_reached"
&&
errType
!=
"rate_limit_exceeded"
{
return
nil
}
// 优先使用 resets_at(Unix 时间戳)
if
resetsAt
,
ok
:=
errObj
[
"resets_at"
]
.
(
float64
);
ok
{
ts
:=
int64
(
resetsAt
)
return
&
ts
}
if
resetsAt
,
ok
:=
errObj
[
"resets_at"
]
.
(
string
);
ok
{
if
ts
,
err
:=
strconv
.
ParseInt
(
resetsAt
,
10
,
64
);
err
==
nil
{
return
&
ts
}
}
// 如果没有 resets_at,尝试使用 resets_in_seconds
if
resetsInSeconds
,
ok
:=
errObj
[
"resets_in_seconds"
]
.
(
float64
);
ok
{
ts
:=
time
.
Now
()
.
Unix
()
+
int64
(
resetsInSeconds
)
return
&
ts
}
if
resetsInSeconds
,
ok
:=
errObj
[
"resets_in_seconds"
]
.
(
string
);
ok
{
if
sec
,
err
:=
strconv
.
ParseInt
(
resetsInSeconds
,
10
,
64
);
err
==
nil
{
ts
:=
time
.
Now
()
.
Unix
()
+
sec
return
&
ts
}
}
return
nil
}
// handle529 处理529过载错误
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
// 根据配置设置过载冷却时间
func
(
s
*
RateLimitService
)
handle529
(
ctx
context
.
Context
,
account
*
Account
)
{
func
(
s
*
RateLimitService
)
handle529
(
ctx
context
.
Context
,
account
*
Account
)
{
...
...
backend/internal/service/ratelimit_service_openai_test.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"net/http"
"testing"
"time"
)
func
TestCalculateOpenAI429ResetTime_7dExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
// ~4.5 days
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
// ~4.8 hours
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 384607 seconds from now
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_5hExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 5h limit is exhausted (100% used)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 3600 seconds from now
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Neither limit at 100%, should use the longer reset time
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"80"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"100000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"90"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"5000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should use the max (100000 seconds from 7d window)
expectedDuration
:=
100000
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NoCodexHeaders
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// No codex headers at all
headers
:=
http
.
Header
{}
headers
.
Set
(
"content-type"
,
"application/json"
)
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil resetAt when no codex headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_ReversedWindowOrder
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// This is 5h
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"300"
)
// 5 hours - smaller!
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"10080"
)
// 7 days - larger!
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestNormalizedCodexLimits
(
t
*
testing
.
T
)
{
// Test the Normalize() method directly
pUsed
:=
100.0
pReset
:=
384607
pWindow
:=
10080
sUsed
:=
3.0
sReset
:=
17369
sWindow
:=
300
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
PrimaryWindowMinutes
:
&
pWindow
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
SecondaryWindowMinutes
:
&
sWindow
,
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Primary has larger window (10080 > 300), so primary should be 7d
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
384607
{
t
.
Errorf
(
"expected Reset7dSeconds=384607, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
3.0
{
t
.
Errorf
(
"expected Used5hPercent=3, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
17369
{
t
.
Errorf
(
"expected Reset5hSeconds=17369, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlyPrimaryData
(
t
*
testing
.
T
)
{
// Test when only primary has data, no window_minutes
pUsed
:=
80.0
pReset
:=
50000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
// No window_minutes, no secondary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
80.0
{
t
.
Errorf
(
"expected Used7dPercent=80, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
50000
{
t
.
Errorf
(
"expected Reset7dSeconds=50000, got %v"
,
normalized
.
Reset7dSeconds
)
}
// Secondary (5h) should be nil
if
normalized
.
Used5hPercent
!=
nil
{
t
.
Errorf
(
"expected Used5hPercent=nil, got %v"
,
*
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
!=
nil
{
t
.
Errorf
(
"expected Reset5hSeconds=nil, got %v"
,
*
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlySecondaryData
(
t
*
testing
.
T
)
{
// Test when only secondary has data, no window_minutes
sUsed
:=
60.0
sReset
:=
3000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes, no primary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
60.0
{
t
.
Errorf
(
"expected Used5hPercent=60, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
3000
{
t
.
Errorf
(
"expected Reset5hSeconds=3000, got %v"
,
normalized
.
Reset5hSeconds
)
}
// Primary (7d) should be nil
if
normalized
.
Used7dPercent
!=
nil
{
t
.
Errorf
(
"expected Used7dPercent=nil, got %v"
,
*
normalized
.
Used7dPercent
)
}
}
func
TestNormalizedCodexLimits_BothDataNoWindowMinutes
(
t
*
testing
.
T
)
{
// Test when both have data but no window_minutes
pUsed
:=
100.0
pReset
:=
400000
sUsed
:=
50.0
sReset
:=
10000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
400000
{
t
.
Errorf
(
"expected Reset7dSeconds=400000, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
50.0
{
t
.
Errorf
(
"expected Used5hPercent=50, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
10000
{
t
.
Errorf
(
"expected Reset5hSeconds=10000, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestHandle429_AnthropicPlatformUnaffected
(
t
*
testing
.
T
)
{
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc
:=
&
RateLimitService
{}
// Simulate Anthropic 429 headers
headers
:=
http
.
Header
{}
headers
.
Set
(
"anthropic-ratelimit-unified-reset"
,
"1737820800"
)
// A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there are no x-codex-* headers
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil for Anthropic headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_UserProvidedScenario
(
t
*
testing
.
T
)
{
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc
:=
&
RateLimitService
{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days = 10080 minutes
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours = 300 minutes
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt for user scenario"
)
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration
:=
resetAt
.
Sub
(
before
)
actualDays
:=
duration
.
Hours
()
/
24.0
// 384607 / 86400 = ~4.45 days
if
actualDays
<
4.4
||
actualDays
>
4.5
{
t
.
Errorf
(
"expected ~4.45 days, got %.2f days"
,
actualDays
)
}
t
.
Logf
(
"User scenario: reset_at=%v, duration=%.2f days"
,
resetAt
,
actualDays
)
}
func
TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset
(
t
*
testing
.
T
)
{
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc
:=
&
RateLimitService
{}
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// No reset_after_seconds!
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there's no reset time available
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil when no reset_after_seconds, got %v"
,
resetAt
)
}
}
backend/internal/service/setting_service.go
View file @
a161fcc8
...
@@ -61,6 +61,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
...
@@ -61,6 +61,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyRegistrationEnabled
,
SettingKeyRegistrationEnabled
,
SettingKeyEmailVerifyEnabled
,
SettingKeyEmailVerifyEnabled
,
SettingKeyPromoCodeEnabled
,
SettingKeyPromoCodeEnabled
,
SettingKeyPasswordResetEnabled
,
SettingKeyTotpEnabled
,
SettingKeyTurnstileEnabled
,
SettingKeyTurnstileEnabled
,
SettingKeyTurnstileSiteKey
,
SettingKeyTurnstileSiteKey
,
SettingKeySiteName
,
SettingKeySiteName
,
...
@@ -86,21 +88,27 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
...
@@ -86,21 +88,27 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
linuxDoEnabled
=
s
.
cfg
!=
nil
&&
s
.
cfg
.
LinuxDo
.
Enabled
linuxDoEnabled
=
s
.
cfg
!=
nil
&&
s
.
cfg
.
LinuxDo
.
Enabled
}
}
// Password reset requires email verification to be enabled
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
passwordResetEnabled
:=
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
return
&
PublicSettings
{
return
&
PublicSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
emailVerifyEnabled
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
PasswordResetEnabled
:
passwordResetEnabled
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
SiteLogo
:
settings
[
SettingKeySiteLogo
],
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
APIBaseURL
:
settings
[
SettingKeyAPIBaseURL
],
SiteLogo
:
settings
[
SettingKeySiteLogo
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
DocURL
:
settings
[
SettingKeyDocURL
],
APIBaseURL
:
settings
[
SettingKeyAPIBaseURL
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
DocURL
:
settings
[
SettingKeyDocURL
],
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
HomeContent
:
settings
[
SettingKeyHomeContent
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
},
nil
},
nil
}
}
...
@@ -125,37 +133,41 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
...
@@ -125,37 +133,41 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// Return a struct that matches the frontend's expected format
// Return a struct that matches the frontend's expected format
return
&
struct
{
return
&
struct
{
RegistrationEnabled
bool
`json:"registration_enabled"`
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key,omitempty"`
TotpEnabled
bool
`json:"totp_enabled"`
SiteName
string
`json:"site_name"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
SiteLogo
string
`json:"site_logo,omitempty"`
TurnstileSiteKey
string
`json:"turnstile_site_key,omitempty"`
SiteSubtitle
string
`json:"site_subtitle,omitempty"`
SiteName
string
`json:"site_name"`
APIBaseURL
string
`json:"api_base_url,omitempty"`
SiteLogo
string
`json:"site_logo,omitempty"`
ContactInfo
string
`json:"contact_info,omitempty"`
SiteSubtitle
string
`json:"site_subtitle,omitempty"`
DocURL
string
`json:"doc_url,omitempty"`
APIBaseURL
string
`json:"api_base_url,omitempty"`
HomeContent
string
`json:"home_content,omitempty"`
ContactInfo
string
`json:"contact_info,omitempty"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
DocURL
string
`json:"doc_url,omitempty"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
HomeContent
string
`json:"home_content,omitempty"`
Version
string
`json:"version,omitempty"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version,omitempty"`
}{
}{
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
PromoCodeEnabled
:
settings
.
PromoCodeEnabled
,
PromoCodeEnabled
:
settings
.
PromoCodeEnabled
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
PasswordResetEnabled
:
settings
.
PasswordResetEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
TotpEnabled
:
settings
.
TotpEnabled
,
SiteName
:
settings
.
SiteName
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
SiteLogo
:
settings
.
SiteLogo
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
SiteName
:
settings
.
SiteName
,
APIBaseURL
:
settings
.
APIBaseURL
,
SiteLogo
:
settings
.
SiteLogo
,
ContactInfo
:
settings
.
ContactInfo
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
DocURL
:
settings
.
DocURL
,
APIBaseURL
:
settings
.
APIBaseURL
,
HomeContent
:
settings
.
HomeContent
,
ContactInfo
:
settings
.
ContactInfo
,
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
DocURL
:
settings
.
DocURL
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
HomeContent
:
settings
.
HomeContent
,
Version
:
s
.
version
,
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
Version
:
s
.
version
,
},
nil
},
nil
}
}
...
@@ -167,6 +179,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
...
@@ -167,6 +179,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyPromoCodeEnabled
]
=
strconv
.
FormatBool
(
settings
.
PromoCodeEnabled
)
updates
[
SettingKeyPromoCodeEnabled
]
=
strconv
.
FormatBool
(
settings
.
PromoCodeEnabled
)
updates
[
SettingKeyPasswordResetEnabled
]
=
strconv
.
FormatBool
(
settings
.
PasswordResetEnabled
)
updates
[
SettingKeyTotpEnabled
]
=
strconv
.
FormatBool
(
settings
.
TotpEnabled
)
// 邮件服务设置(只有非空才更新密码)
// 邮件服务设置(只有非空才更新密码)
updates
[
SettingKeySMTPHost
]
=
settings
.
SMTPHost
updates
[
SettingKeySMTPHost
]
=
settings
.
SMTPHost
...
@@ -262,6 +276,35 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
...
@@ -262,6 +276,35 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
return
value
!=
"false"
return
value
!=
"false"
}
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func
(
s
*
SettingService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
// Password reset requires email verification to be enabled
if
!
s
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyPasswordResetEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
func
(
s
*
SettingService
)
IsTotpEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyTotpEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
func
(
s
*
SettingService
)
IsTotpEncryptionKeyConfigured
()
bool
{
return
s
.
cfg
.
Totp
.
EncryptionKeyConfigured
}
// GetSiteName 获取网站名称
// GetSiteName 获取网站名称
func
(
s
*
SettingService
)
GetSiteName
(
ctx
context
.
Context
)
string
{
func
(
s
*
SettingService
)
GetSiteName
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
...
@@ -340,10 +383,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
...
@@ -340,10 +383,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
// parseSettings 解析设置到结构体
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
result
:=
&
SystemSettings
{
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyE
mailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
e
mailVerifyEnabled
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PasswordResetEnabled
:
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
,
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
...
...
backend/internal/service/settings_view.go
View file @
a161fcc8
package
service
package
service
type
SystemSettings
struct
{
type
SystemSettings
struct
{
RegistrationEnabled
bool
RegistrationEnabled
bool
EmailVerifyEnabled
bool
EmailVerifyEnabled
bool
PromoCodeEnabled
bool
PromoCodeEnabled
bool
PasswordResetEnabled
bool
TotpEnabled
bool
// TOTP 双因素认证
SMTPHost
string
SMTPHost
string
SMTPPort
int
SMTPPort
int
...
@@ -57,21 +59,23 @@ type SystemSettings struct {
...
@@ -57,21 +59,23 @@ type SystemSettings struct {
}
}
type
PublicSettings
struct
{
type
PublicSettings
struct
{
RegistrationEnabled
bool
RegistrationEnabled
bool
EmailVerifyEnabled
bool
EmailVerifyEnabled
bool
PromoCodeEnabled
bool
PromoCodeEnabled
bool
TurnstileEnabled
bool
PasswordResetEnabled
bool
TurnstileSiteKey
string
TotpEnabled
bool
// TOTP 双因素认证
SiteName
string
TurnstileEnabled
bool
SiteLogo
string
TurnstileSiteKey
string
SiteSubtitle
string
SiteName
string
APIBaseURL
string
SiteLogo
string
ContactInfo
string
SiteSubtitle
string
DocURL
string
APIBaseURL
string
HomeContent
string
ContactInfo
string
HideCcsImportButton
bool
DocURL
string
LinuxDoOAuthEnabled
bool
HomeContent
string
Version
string
HideCcsImportButton
bool
LinuxDoOAuthEnabled
bool
Version
string
}
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
...
...
backend/internal/service/subscription_expiry_service.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"context"
"log"
"sync"
"time"
)
// SubscriptionExpiryService periodically updates expired subscription status.
type
SubscriptionExpiryService
struct
{
userSubRepo
UserSubscriptionRepository
interval
time
.
Duration
stopCh
chan
struct
{}
stopOnce
sync
.
Once
wg
sync
.
WaitGroup
}
func
NewSubscriptionExpiryService
(
userSubRepo
UserSubscriptionRepository
,
interval
time
.
Duration
)
*
SubscriptionExpiryService
{
return
&
SubscriptionExpiryService
{
userSubRepo
:
userSubRepo
,
interval
:
interval
,
stopCh
:
make
(
chan
struct
{}),
}
}
func
(
s
*
SubscriptionExpiryService
)
Start
()
{
if
s
==
nil
||
s
.
userSubRepo
==
nil
||
s
.
interval
<=
0
{
return
}
s
.
wg
.
Add
(
1
)
go
func
()
{
defer
s
.
wg
.
Done
()
ticker
:=
time
.
NewTicker
(
s
.
interval
)
defer
ticker
.
Stop
()
s
.
runOnce
()
for
{
select
{
case
<-
ticker
.
C
:
s
.
runOnce
()
case
<-
s
.
stopCh
:
return
}
}
}()
}
func
(
s
*
SubscriptionExpiryService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
s
.
wg
.
Wait
()
}
func
(
s
*
SubscriptionExpiryService
)
runOnce
()
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
updated
,
err
:=
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
if
err
!=
nil
{
log
.
Printf
(
"[SubscriptionExpiry] Update expired subscriptions failed: %v"
,
err
)
return
}
if
updated
>
0
{
log
.
Printf
(
"[SubscriptionExpiry] Updated %d expired subscriptions"
,
updated
)
}
}
backend/internal/service/subscription_service.go
View file @
a161fcc8
...
@@ -324,18 +324,31 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
...
@@ -324,18 +324,31 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
days
=
-
MaxValidityDays
days
=
-
MaxValidityDays
}
}
now
:=
time
.
Now
()
isExpired
:=
!
sub
.
ExpiresAt
.
After
(
now
)
// 如果订阅已过期,不允许负向调整
if
isExpired
&&
days
<
0
{
return
nil
,
infraerrors
.
BadRequest
(
"CANNOT_SHORTEN_EXPIRED"
,
"cannot shorten an expired subscription"
)
}
// 计算新的过期时间
// 计算新的过期时间
newExpiresAt
:=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
var
newExpiresAt
time
.
Time
if
isExpired
{
// 已过期:从当前时间开始增加天数
newExpiresAt
=
now
.
AddDate
(
0
,
0
,
days
)
}
else
{
// 未过期:从原过期时间增加/减少天数
newExpiresAt
=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
}
if
newExpiresAt
.
After
(
MaxExpiresAt
)
{
if
newExpiresAt
.
After
(
MaxExpiresAt
)
{
newExpiresAt
=
MaxExpiresAt
newExpiresAt
=
MaxExpiresAt
}
}
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
// 检查新的过期时间必须大于当前时间
if
days
<
0
{
if
!
newExpiresAt
.
After
(
now
)
{
now
:=
time
.
Now
()
return
nil
,
ErrAdjustWouldExpire
if
!
newExpiresAt
.
After
(
now
)
{
return
nil
,
ErrAdjustWouldExpire
}
}
}
if
err
:=
s
.
userSubRepo
.
ExtendExpiry
(
ctx
,
subscriptionID
,
newExpiresAt
);
err
!=
nil
{
if
err
:=
s
.
userSubRepo
.
ExtendExpiry
(
ctx
,
subscriptionID
,
newExpiresAt
);
err
!=
nil
{
...
@@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
...
@@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
return
nil
,
err
return
nil
,
err
}
}
normalizeExpiredWindows
(
subs
)
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
nil
return
subs
,
nil
}
}
...
@@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
...
@@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
normalizeExpiredWindows
(
subs
)
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
return
subs
,
pag
,
nil
}
}
// List 获取所有订阅(分页,支持筛选)
// List 获取所有订阅(分页,支持筛选
和排序
)
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
)
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
,
sortBy
,
sortOrder
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
normalizeExpiredWindows
(
subs
)
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
return
subs
,
pag
,
nil
}
}
...
@@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
...
@@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
}
}
}
}
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
func
normalizeSubscriptionStatus
(
subs
[]
UserSubscription
)
{
now
:=
time
.
Now
()
for
i
:=
range
subs
{
sub
:=
&
subs
[
i
]
if
sub
.
Status
==
SubscriptionStatusActive
&&
!
sub
.
ExpiresAt
.
After
(
now
)
{
sub
.
Status
=
SubscriptionStatusExpired
}
}
}
// startOfDay 返回给定时间所在日期的零点(保持原时区)
// startOfDay 返回给定时间所在日期的零点(保持原时区)
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
...
@@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
...
@@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
return
progresses
,
nil
return
progresses
,
nil
}
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func
(
s
*
SubscriptionService
)
UpdateExpiredSubscriptions
(
ctx
context
.
Context
)
(
int64
,
error
)
{
return
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
}
// ValidateSubscription 验证订阅是否有效
// ValidateSubscription 验证订阅是否有效
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
...
...
backend/internal/service/token_cache_invalidator.go
View file @
a161fcc8
package
service
package
service
import
"context"
import
(
"context"
"log/slog"
"strconv"
)
type
TokenCacheInvalidator
interface
{
type
TokenCacheInvalidator
interface
{
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
...
@@ -24,18 +28,87 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
...
@@ -24,18 +28,87 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
return
nil
return
nil
}
}
var
cacheKey
string
var
keysToDelete
[]
string
accountIDKey
:=
"account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
switch
account
.
Platform
{
switch
account
.
Platform
{
case
PlatformGemini
:
case
PlatformGemini
:
cacheKey
=
GeminiTokenCacheKey
(
account
)
// Gemini 可能有两种缓存键:project_id 或 account_id
// 首次获取 token 时可能没有 project_id,之后自动检测到 project_id 后会使用新 key
// 刷新时需要同时删除两种可能的 key,确保不会遗留旧缓存
keysToDelete
=
append
(
keysToDelete
,
GeminiTokenCacheKey
(
account
))
keysToDelete
=
append
(
keysToDelete
,
"gemini:"
+
accountIDKey
)
case
PlatformAntigravity
:
case
PlatformAntigravity
:
cacheKey
=
AntigravityTokenCacheKey
(
account
)
// Antigravity 同样可能有两种缓存键
keysToDelete
=
append
(
keysToDelete
,
AntigravityTokenCacheKey
(
account
))
keysToDelete
=
append
(
keysToDelete
,
"ag:"
+
accountIDKey
)
case
PlatformOpenAI
:
case
PlatformOpenAI
:
cacheKey
=
OpenAITokenCacheKey
(
account
)
keysToDelete
=
append
(
keysToDelete
,
OpenAITokenCacheKey
(
account
)
)
case
PlatformAnthropic
:
case
PlatformAnthropic
:
cacheKey
=
ClaudeTokenCacheKey
(
account
)
keysToDelete
=
append
(
keysToDelete
,
ClaudeTokenCacheKey
(
account
)
)
default
:
default
:
return
nil
return
nil
}
}
return
c
.
cache
.
DeleteAccessToken
(
ctx
,
cacheKey
)
// 删除所有可能的缓存键(去重后)
seen
:=
make
(
map
[
string
]
bool
)
for
_
,
key
:=
range
keysToDelete
{
if
seen
[
key
]
{
continue
}
seen
[
key
]
=
true
if
err
:=
c
.
cache
.
DeleteAccessToken
(
ctx
,
key
);
err
!=
nil
{
slog
.
Warn
(
"token_cache_delete_failed"
,
"key"
,
key
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
nil
}
// CheckTokenVersion 检查 account 的 token 版本是否已过时,并返回最新的 account
// 用于解决异步刷新任务与请求线程的竞态条件:
// 如果刷新任务已更新 token 并删除缓存,此时请求线程的旧 account 对象不应写入缓存
//
// 返回值:
// - latestAccount: 从 DB 获取的最新 account(如果查询失败则返回 nil)
// - isStale: true 表示 token 已过时(应使用 latestAccount),false 表示可以使用当前 account
func
CheckTokenVersion
(
ctx
context
.
Context
,
account
*
Account
,
repo
AccountRepository
)
(
latestAccount
*
Account
,
isStale
bool
)
{
if
account
==
nil
||
repo
==
nil
{
return
nil
,
false
}
currentVersion
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
latestAccount
,
err
:=
repo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
!=
nil
||
latestAccount
==
nil
{
// 查询失败,默认允许缓存,不返回 latestAccount
return
nil
,
false
}
latestVersion
:=
latestAccount
.
GetCredentialAsInt64
(
"_token_version"
)
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
// 说明异步刷新任务已更新 token,当前 account 已过时
if
currentVersion
==
0
&&
latestVersion
>
0
{
slog
.
Debug
(
"token_version_stale_no_current_version"
,
"account_id"
,
account
.
ID
,
"latest_version"
,
latestVersion
)
return
latestAccount
,
true
}
// 情况2: 两边都没有版本号,说明从未被异步刷新过,允许缓存
if
currentVersion
==
0
&&
latestVersion
==
0
{
return
latestAccount
,
false
}
// 情况3: 比较版本号,如果 DB 中的版本更新,当前 account 已过时
if
latestVersion
>
currentVersion
{
slog
.
Debug
(
"token_version_stale"
,
"account_id"
,
account
.
ID
,
"current_version"
,
currentVersion
,
"latest_version"
,
latestVersion
)
return
latestAccount
,
true
}
return
latestAccount
,
false
}
}
backend/internal/service/token_cache_invalidator_test.go
View file @
a161fcc8
...
@@ -51,7 +51,27 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
...
@@ -51,7 +51,27 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"gemini:project-x"
},
cache
.
deletedKeys
)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
// 这是为了处理:首次获取 token 时可能没有 project_id,之后自动检测到后会使用新 key
require
.
Equal
(
t
,
[]
string
{
"gemini:project-x"
,
"gemini:account:10"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_GeminiWithoutProjectID
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
10
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"gemini-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require
.
Equal
(
t
,
[]
string
{
"gemini:account:10"
},
cache
.
deletedKeys
)
}
}
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
...
@@ -68,7 +88,26 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
...
@@ -68,7 +88,26 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
},
cache
.
deletedKeys
)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
,
"ag:account:99"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_AntigravityWithoutProjectID
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
99
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"ag-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require
.
Equal
(
t
,
[]
string
{
"ag:account:99"
},
cache
.
deletedKeys
)
}
}
func
TestCompositeTokenCacheInvalidator_OpenAI
(
t
*
testing
.
T
)
{
func
TestCompositeTokenCacheInvalidator_OpenAI
(
t
*
testing
.
T
)
{
...
@@ -233,9 +272,10 @@ func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
...
@@ -233,9 +272,10 @@ func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
for
_
,
tt
:=
range
tests
{
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 新行为:删除失败只记录日志,不返回错误
// 这是因为缓存失效失败不应影响主业务流程
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
expectedErr
,
err
)
})
})
}
}
}
}
...
@@ -252,9 +292,12 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
...
@@ -252,9 +292,12 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
},
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
},
}
}
// 新行为:Gemini 和 Antigravity 会同时删除基于 project_id 和 account_id 的键
expectedKeys
:=
[]
string
{
expectedKeys
:=
[]
string
{
"gemini:gemini-proj"
,
"gemini:gemini-proj"
,
"gemini:account:1"
,
"ag:ag-proj"
,
"ag:ag-proj"
,
"ag:account:2"
,
"openai:account:3"
,
"openai:account:3"
,
"claude:account:4"
,
"claude:account:4"
,
}
}
...
@@ -266,3 +309,239 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
...
@@ -266,3 +309,239 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
require
.
Equal
(
t
,
expectedKeys
,
cache
.
deletedKeys
)
require
.
Equal
(
t
,
expectedKeys
,
cache
.
deletedKeys
)
}
}
// ========== GetCredentialAsInt64 测试 ==========
func
TestAccount_GetCredentialAsInt64
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
credentials
map
[
string
]
any
key
string
expected
int64
}{
{
name
:
"int64_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
1737654321000
)},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"float64_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
float64
(
1737654321000
)},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"int_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
12345
},
key
:
"_token_version"
,
expected
:
12345
,
},
{
name
:
"string_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
"1737654321000"
},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"string_with_spaces"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
" 1737654321000 "
},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"nil_credentials"
,
credentials
:
nil
,
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"missing_key"
,
credentials
:
map
[
string
]
any
{
"other_key"
:
123
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"nil_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
nil
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"invalid_string"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
"not_a_number"
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"empty_string"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
""
},
key
:
"_token_version"
,
expected
:
0
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Credentials
:
tt
.
credentials
}
result
:=
account
.
GetCredentialAsInt64
(
tt
.
key
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestAccount_GetCredentialAsInt64_NilAccount
(
t
*
testing
.
T
)
{
var
account
*
Account
result
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
require
.
Equal
(
t
,
int64
(
0
),
result
)
}
// ========== CheckTokenVersion 测试 ==========
func
TestCheckTokenVersion
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
latestAccount
*
Account
repoErr
error
expectedStale
bool
}{
{
name
:
"nil_account"
,
account
:
nil
,
latestAccount
:
nil
,
expectedStale
:
false
,
},
{
name
:
"no_version_in_account_but_db_has_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
true
,
// 当前 account 无版本但 DB 有,说明已被异步刷新,当前已过时
},
{
name
:
"both_no_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
expectedStale
:
false
,
// 两边都没有版本号,说明从未被异步刷新过,允许缓存
},
{
name
:
"same_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
false
,
},
{
name
:
"current_version_newer"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
200
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
false
,
},
{
name
:
"current_version_older_stale"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
200
)},
},
expectedStale
:
true
,
// 当前版本过时
},
{
name
:
"repo_error"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
nil
,
repoErr
:
errors
.
New
(
"db error"
),
expectedStale
:
false
,
// 查询失败,默认允许缓存
},
{
name
:
"repo_returns_nil"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
nil
,
repoErr
:
nil
,
expectedStale
:
false
,
// 查询返回 nil,默认允许缓存
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 由于 CheckTokenVersion 接受 AccountRepository 接口,而创建完整的 mock 很繁琐
// 这里我们直接测试函数的核心逻辑来验证行为
if
tt
.
name
==
"nil_account"
{
_
,
isStale
:=
CheckTokenVersion
(
context
.
Background
(),
nil
,
nil
)
require
.
Equal
(
t
,
tt
.
expectedStale
,
isStale
)
return
}
// 模拟 CheckTokenVersion 的核心逻辑
account
:=
tt
.
account
currentVersion
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
// 模拟 repo 查询
latestAccount
:=
tt
.
latestAccount
if
tt
.
repoErr
!=
nil
||
latestAccount
==
nil
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
false
)
return
}
latestVersion
:=
latestAccount
.
GetCredentialAsInt64
(
"_token_version"
)
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
if
currentVersion
==
0
&&
latestVersion
>
0
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
true
)
return
}
// 情况2: 两边都没有版本号
if
currentVersion
==
0
&&
latestVersion
==
0
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
false
)
return
}
// 情况3: 比较版本号
isStale
:=
latestVersion
>
currentVersion
require
.
Equal
(
t
,
tt
.
expectedStale
,
isStale
)
})
}
}
func
TestCheckTokenVersion_NilRepo
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
}
_
,
isStale
:=
CheckTokenVersion
(
context
.
Background
(),
account
,
nil
)
require
.
False
(
t
,
isStale
)
// nil repo,默认允许缓存
}
backend/internal/service/token_refresh_service.go
View file @
a161fcc8
...
@@ -169,6 +169,10 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
...
@@ -169,6 +169,10 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 如果有新凭证,先更新(即使有错误也要保存 token)
// 如果有新凭证,先更新(即使有错误也要保存 token)
if
newCredentials
!=
nil
{
if
newCredentials
!=
nil
{
// 记录刷新版本时间戳,用于解决缓存一致性问题
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
account
.
Credentials
=
newCredentials
account
.
Credentials
=
newCredentials
if
saveErr
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
saveErr
!=
nil
{
if
saveErr
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
saveErr
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
...
@@ -233,7 +237,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
...
@@ -233,7 +237,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
}
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效,需要用户重新授权
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
func
isNonRetryableRefreshError
(
err
error
)
bool
{
func
isNonRetryableRefreshError
(
err
error
)
bool
{
if
err
==
nil
{
if
err
==
nil
{
return
false
return
false
...
...
backend/internal/service/totp_service.go
0 → 100644
View file @
a161fcc8
package
service
import
(
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log/slog"
"time"
"github.com/pquerna/otp/totp"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var
(
ErrTotpNotEnabled
=
infraerrors
.
BadRequest
(
"TOTP_NOT_ENABLED"
,
"totp feature is not enabled"
)
ErrTotpAlreadyEnabled
=
infraerrors
.
BadRequest
(
"TOTP_ALREADY_ENABLED"
,
"totp is already enabled for this account"
)
ErrTotpNotSetup
=
infraerrors
.
BadRequest
(
"TOTP_NOT_SETUP"
,
"totp is not set up for this account"
)
ErrTotpInvalidCode
=
infraerrors
.
BadRequest
(
"TOTP_INVALID_CODE"
,
"invalid totp code"
)
ErrTotpSetupExpired
=
infraerrors
.
BadRequest
(
"TOTP_SETUP_EXPIRED"
,
"totp setup session expired"
)
ErrTotpTooManyAttempts
=
infraerrors
.
TooManyRequests
(
"TOTP_TOO_MANY_ATTEMPTS"
,
"too many verification attempts, please try again later"
)
ErrVerifyCodeRequired
=
infraerrors
.
BadRequest
(
"VERIFY_CODE_REQUIRED"
,
"email verification code is required"
)
ErrPasswordRequired
=
infraerrors
.
BadRequest
(
"PASSWORD_REQUIRED"
,
"password is required"
)
)
// TotpCache defines cache operations for TOTP service
type
TotpCache
interface
{
// Setup session methods
GetSetupSession
(
ctx
context
.
Context
,
userID
int64
)
(
*
TotpSetupSession
,
error
)
SetSetupSession
(
ctx
context
.
Context
,
userID
int64
,
session
*
TotpSetupSession
,
ttl
time
.
Duration
)
error
DeleteSetupSession
(
ctx
context
.
Context
,
userID
int64
)
error
// Login session methods (for 2FA login flow)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
TotpLoginSession
,
error
)
SetLoginSession
(
ctx
context
.
Context
,
tempToken
string
,
session
*
TotpLoginSession
,
ttl
time
.
Duration
)
error
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
// Rate limiting
IncrementVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
GetVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
ClearVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
error
}
// SecretEncryptor defines encryption operations for TOTP secrets
type
SecretEncryptor
interface
{
Encrypt
(
plaintext
string
)
(
string
,
error
)
Decrypt
(
ciphertext
string
)
(
string
,
error
)
}
// TotpSetupSession represents a TOTP setup session
type
TotpSetupSession
struct
{
Secret
string
// Plain text TOTP secret (not encrypted yet)
SetupToken
string
// Random token to verify setup request
CreatedAt
time
.
Time
}
// TotpLoginSession represents a pending 2FA login session
type
TotpLoginSession
struct
{
UserID
int64
Email
string
TokenExpiry
time
.
Time
}
// TotpStatus represents the TOTP status for a user
type
TotpStatus
struct
{
Enabled
bool
`json:"enabled"`
EnabledAt
*
time
.
Time
`json:"enabled_at,omitempty"`
FeatureEnabled
bool
`json:"feature_enabled"`
}
// TotpSetupResponse represents the response for initiating TOTP setup
type
TotpSetupResponse
struct
{
Secret
string
`json:"secret"`
QRCodeURL
string
`json:"qr_code_url"`
SetupToken
string
`json:"setup_token"`
Countdown
int
`json:"countdown"`
// seconds until setup expires
}
const
(
totpSetupTTL
=
5
*
time
.
Minute
totpLoginTTL
=
5
*
time
.
Minute
totpAttemptsTTL
=
15
*
time
.
Minute
maxTotpAttempts
=
5
totpIssuer
=
"Sub2API"
)
// TotpService handles TOTP operations
type
TotpService
struct
{
userRepo
UserRepository
encryptor
SecretEncryptor
cache
TotpCache
settingService
*
SettingService
emailService
*
EmailService
emailQueueService
*
EmailQueueService
}
// NewTotpService creates a new TOTP service
func
NewTotpService
(
userRepo
UserRepository
,
encryptor
SecretEncryptor
,
cache
TotpCache
,
settingService
*
SettingService
,
emailService
*
EmailService
,
emailQueueService
*
EmailQueueService
,
)
*
TotpService
{
return
&
TotpService
{
userRepo
:
userRepo
,
encryptor
:
encryptor
,
cache
:
cache
,
settingService
:
settingService
,
emailService
:
emailService
,
emailQueueService
:
emailQueueService
,
}
}
// GetStatus returns the TOTP status for a user
func
(
s
*
TotpService
)
GetStatus
(
ctx
context
.
Context
,
userID
int64
)
(
*
TotpStatus
,
error
)
{
featureEnabled
:=
s
.
settingService
.
IsTotpEnabled
(
ctx
)
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
return
&
TotpStatus
{
Enabled
:
user
.
TotpEnabled
,
EnabledAt
:
user
.
TotpEnabledAt
,
FeatureEnabled
:
featureEnabled
,
},
nil
}
// InitiateSetup starts the TOTP setup process
// If email verification is enabled, emailCode is required; otherwise password is required
func
(
s
*
TotpService
)
InitiateSetup
(
ctx
context
.
Context
,
userID
int64
,
emailCode
,
password
string
)
(
*
TotpSetupResponse
,
error
)
{
// Check if TOTP feature is enabled globally
if
!
s
.
settingService
.
IsTotpEnabled
(
ctx
)
{
return
nil
,
ErrTotpNotEnabled
}
// Get user and check if TOTP is already enabled
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
user
.
TotpEnabled
{
return
nil
,
ErrTotpAlreadyEnabled
}
// Verify identity based on email verification setting
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// Email verification enabled - verify email code
if
emailCode
==
""
{
return
nil
,
ErrVerifyCodeRequired
}
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
user
.
Email
,
emailCode
);
err
!=
nil
{
return
nil
,
err
}
}
else
{
// Email verification disabled - verify password
if
password
==
""
{
return
nil
,
ErrPasswordRequired
}
if
!
user
.
CheckPassword
(
password
)
{
return
nil
,
ErrPasswordIncorrect
}
}
// Generate a new TOTP key
key
,
err
:=
totp
.
Generate
(
totp
.
GenerateOpts
{
Issuer
:
totpIssuer
,
AccountName
:
user
.
Email
,
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate totp key: %w"
,
err
)
}
// Generate a random setup token
setupToken
,
err
:=
generateRandomToken
(
32
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate setup token: %w"
,
err
)
}
// Store the setup session in cache
session
:=
&
TotpSetupSession
{
Secret
:
key
.
Secret
(),
SetupToken
:
setupToken
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetSetupSession
(
ctx
,
userID
,
session
,
totpSetupTTL
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"store setup session: %w"
,
err
)
}
return
&
TotpSetupResponse
{
Secret
:
key
.
Secret
(),
QRCodeURL
:
key
.
URL
(),
SetupToken
:
setupToken
,
Countdown
:
int
(
totpSetupTTL
.
Seconds
()),
},
nil
}
// CompleteSetup completes the TOTP setup by verifying the code
func
(
s
*
TotpService
)
CompleteSetup
(
ctx
context
.
Context
,
userID
int64
,
totpCode
,
setupToken
string
)
error
{
// Check if TOTP feature is enabled globally
if
!
s
.
settingService
.
IsTotpEnabled
(
ctx
)
{
return
ErrTotpNotEnabled
}
// Get the setup session
session
,
err
:=
s
.
cache
.
GetSetupSession
(
ctx
,
userID
)
if
err
!=
nil
{
return
ErrTotpSetupExpired
}
if
session
==
nil
{
return
ErrTotpSetupExpired
}
// Verify the setup token (constant-time comparison)
if
subtle
.
ConstantTimeCompare
([]
byte
(
session
.
SetupToken
),
[]
byte
(
setupToken
))
!=
1
{
return
ErrTotpSetupExpired
}
// Verify the TOTP code
if
!
totp
.
Validate
(
totpCode
,
session
.
Secret
)
{
return
ErrTotpInvalidCode
}
setupSecretPrefix
:=
"N/A"
if
len
(
session
.
Secret
)
>=
4
{
setupSecretPrefix
=
session
.
Secret
[
:
4
]
}
slog
.
Debug
(
"totp_complete_setup_before_encrypt"
,
"user_id"
,
userID
,
"secret_len"
,
len
(
session
.
Secret
),
"secret_prefix"
,
setupSecretPrefix
)
// Encrypt the secret
encryptedSecret
,
err
:=
s
.
encryptor
.
Encrypt
(
session
.
Secret
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"encrypt totp secret: %w"
,
err
)
}
slog
.
Debug
(
"totp_complete_setup_encrypted"
,
"user_id"
,
userID
,
"encrypted_len"
,
len
(
encryptedSecret
))
// Verify encryption by decrypting
decrypted
,
decErr
:=
s
.
encryptor
.
Decrypt
(
encryptedSecret
)
if
decErr
!=
nil
{
slog
.
Debug
(
"totp_complete_setup_verify_failed"
,
"user_id"
,
userID
,
"error"
,
decErr
)
}
else
{
decryptedPrefix
:=
"N/A"
if
len
(
decrypted
)
>=
4
{
decryptedPrefix
=
decrypted
[
:
4
]
}
slog
.
Debug
(
"totp_complete_setup_verified"
,
"user_id"
,
userID
,
"original_len"
,
len
(
session
.
Secret
),
"decrypted_len"
,
len
(
decrypted
),
"match"
,
session
.
Secret
==
decrypted
,
"decrypted_prefix"
,
decryptedPrefix
)
}
// Update user with encrypted TOTP secret
if
err
:=
s
.
userRepo
.
UpdateTotpSecret
(
ctx
,
userID
,
&
encryptedSecret
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update totp secret: %w"
,
err
)
}
// Enable TOTP for the user
if
err
:=
s
.
userRepo
.
EnableTotp
(
ctx
,
userID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"enable totp: %w"
,
err
)
}
// Clean up the setup session
_
=
s
.
cache
.
DeleteSetupSession
(
ctx
,
userID
)
return
nil
}
// Disable disables TOTP for a user
// If email verification is enabled, emailCode is required; otherwise password is required
func
(
s
*
TotpService
)
Disable
(
ctx
context
.
Context
,
userID
int64
,
emailCode
,
password
string
)
error
{
// Get user
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
!
user
.
TotpEnabled
{
return
ErrTotpNotSetup
}
// Verify identity based on email verification setting
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// Email verification enabled - verify email code
if
emailCode
==
""
{
return
ErrVerifyCodeRequired
}
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
user
.
Email
,
emailCode
);
err
!=
nil
{
return
err
}
}
else
{
// Email verification disabled - verify password
if
password
==
""
{
return
ErrPasswordRequired
}
if
!
user
.
CheckPassword
(
password
)
{
return
ErrPasswordIncorrect
}
}
// Disable TOTP
if
err
:=
s
.
userRepo
.
DisableTotp
(
ctx
,
userID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"disable totp: %w"
,
err
)
}
return
nil
}
// VerifyCode verifies a TOTP code for a user
func
(
s
*
TotpService
)
VerifyCode
(
ctx
context
.
Context
,
userID
int64
,
code
string
)
error
{
slog
.
Debug
(
"totp_verify_code_called"
,
"user_id"
,
userID
,
"code_len"
,
len
(
code
))
// Check rate limiting
attempts
,
err
:=
s
.
cache
.
GetVerifyAttempts
(
ctx
,
userID
)
if
err
==
nil
&&
attempts
>=
maxTotpAttempts
{
return
ErrTotpTooManyAttempts
}
// Get user
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
slog
.
Debug
(
"totp_verify_get_user_failed"
,
"user_id"
,
userID
,
"error"
,
err
)
return
infraerrors
.
InternalServer
(
"TOTP_VERIFY_ERROR"
,
"failed to verify totp code"
)
}
if
!
user
.
TotpEnabled
||
user
.
TotpSecretEncrypted
==
nil
{
slog
.
Debug
(
"totp_verify_not_setup"
,
"user_id"
,
userID
,
"enabled"
,
user
.
TotpEnabled
,
"has_secret"
,
user
.
TotpSecretEncrypted
!=
nil
)
return
ErrTotpNotSetup
}
slog
.
Debug
(
"totp_verify_encrypted_secret"
,
"user_id"
,
userID
,
"encrypted_len"
,
len
(
*
user
.
TotpSecretEncrypted
))
// Decrypt the secret
secret
,
err
:=
s
.
encryptor
.
Decrypt
(
*
user
.
TotpSecretEncrypted
)
if
err
!=
nil
{
slog
.
Debug
(
"totp_verify_decrypt_failed"
,
"user_id"
,
userID
,
"error"
,
err
)
return
infraerrors
.
InternalServer
(
"TOTP_VERIFY_ERROR"
,
"failed to verify totp code"
)
}
secretPrefix
:=
"N/A"
if
len
(
secret
)
>=
4
{
secretPrefix
=
secret
[
:
4
]
}
slog
.
Debug
(
"totp_verify_decrypted"
,
"user_id"
,
userID
,
"secret_len"
,
len
(
secret
),
"secret_prefix"
,
secretPrefix
)
// Verify the code
valid
:=
totp
.
Validate
(
code
,
secret
)
slog
.
Debug
(
"totp_verify_result"
,
"user_id"
,
userID
,
"valid"
,
valid
,
"secret_len"
,
len
(
secret
),
"secret_prefix"
,
secretPrefix
,
"server_time"
,
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
))
if
!
valid
{
// Increment failed attempts
_
,
_
=
s
.
cache
.
IncrementVerifyAttempts
(
ctx
,
userID
)
return
ErrTotpInvalidCode
}
// Clear attempt counter on success
_
=
s
.
cache
.
ClearVerifyAttempts
(
ctx
,
userID
)
return
nil
}
// CreateLoginSession creates a temporary login session for 2FA
func
(
s
*
TotpService
)
CreateLoginSession
(
ctx
context
.
Context
,
userID
int64
,
email
string
)
(
string
,
error
)
{
// Generate a random temp token
tempToken
,
err
:=
generateRandomToken
(
32
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate temp token: %w"
,
err
)
}
session
:=
&
TotpLoginSession
{
UserID
:
userID
,
Email
:
email
,
TokenExpiry
:
time
.
Now
()
.
Add
(
totpLoginTTL
),
}
if
err
:=
s
.
cache
.
SetLoginSession
(
ctx
,
tempToken
,
session
,
totpLoginTTL
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"store login session: %w"
,
err
)
}
return
tempToken
,
nil
}
// GetLoginSession retrieves a login session
func
(
s
*
TotpService
)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
TotpLoginSession
,
error
)
{
return
s
.
cache
.
GetLoginSession
(
ctx
,
tempToken
)
}
// DeleteLoginSession deletes a login session
func
(
s
*
TotpService
)
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
{
return
s
.
cache
.
DeleteLoginSession
(
ctx
,
tempToken
)
}
// IsTotpEnabledForUser checks if TOTP is enabled for a specific user
func
(
s
*
TotpService
)
IsTotpEnabledForUser
(
ctx
context
.
Context
,
userID
int64
)
(
bool
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
false
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
return
user
.
TotpEnabled
,
nil
}
// MaskEmail masks an email address for display
func
MaskEmail
(
email
string
)
string
{
if
len
(
email
)
<
3
{
return
"***"
}
atIdx
:=
-
1
for
i
,
c
:=
range
email
{
if
c
==
'@'
{
atIdx
=
i
break
}
}
if
atIdx
==
-
1
||
atIdx
<
1
{
return
email
[
:
1
]
+
"***"
}
localPart
:=
email
[
:
atIdx
]
domain
:=
email
[
atIdx
:
]
if
len
(
localPart
)
<=
2
{
return
localPart
[
:
1
]
+
"***"
+
domain
}
return
localPart
[
:
1
]
+
"***"
+
localPart
[
len
(
localPart
)
-
1
:
]
+
domain
}
// generateRandomToken generates a random hex-encoded token
func
generateRandomToken
(
byteLength
int
)
(
string
,
error
)
{
b
:=
make
([]
byte
,
byteLength
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
b
),
nil
}
// VerificationMethod represents the method required for TOTP operations
type
VerificationMethod
struct
{
Method
string
`json:"method"`
// "email" or "password"
}
// GetVerificationMethod returns the verification method for TOTP operations
func
(
s
*
TotpService
)
GetVerificationMethod
(
ctx
context
.
Context
)
*
VerificationMethod
{
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
&
VerificationMethod
{
Method
:
"email"
}
}
return
&
VerificationMethod
{
Method
:
"password"
}
}
// SendVerifyCode sends an email verification code for TOTP operations
func
(
s
*
TotpService
)
SendVerifyCode
(
ctx
context
.
Context
,
userID
int64
)
error
{
// Check if email verification is enabled
if
!
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_NOT_ENABLED"
,
"email verification is not enabled"
)
}
// Get user email
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// Get site name for email
siteName
:=
s
.
settingService
.
GetSiteName
(
ctx
)
// Send verification code via queue
return
s
.
emailQueueService
.
EnqueueVerifyCode
(
user
.
Email
,
siteName
)
}
backend/internal/service/user.go
View file @
a161fcc8
...
@@ -21,6 +21,11 @@ type User struct {
...
@@ -21,6 +21,11 @@ type User struct {
CreatedAt
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
UpdatedAt
time
.
Time
// TOTP 双因素认证字段
TotpSecretEncrypted
*
string
// AES-256-GCM 加密的 TOTP 密钥
TotpEnabled
bool
// 是否启用 TOTP
TotpEnabledAt
*
time
.
Time
// TOTP 启用时间
APIKeys
[]
APIKey
APIKeys
[]
APIKey
Subscriptions
[]
UserSubscription
Subscriptions
[]
UserSubscription
}
}
...
...
backend/internal/service/user_service.go
View file @
a161fcc8
...
@@ -38,6 +38,11 @@ type UserRepository interface {
...
@@ -38,6 +38,11 @@ type UserRepository interface {
UpdateConcurrency
(
ctx
context
.
Context
,
id
int64
,
amount
int
)
error
UpdateConcurrency
(
ctx
context
.
Context
,
id
int64
,
amount
int
)
error
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
RemoveGroupFromAllowedGroups
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
RemoveGroupFromAllowedGroups
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
// TOTP 相关方法
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
}
}
// UpdateProfileRequest 更新用户资料请求
// UpdateProfileRequest 更新用户资料请求
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment