Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0170d19f
Commit
0170d19f
authored
Feb 02, 2026
by
song
Browse files
merge upstream main
parent
7ade9baa
Changes
319
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/antigravity_gateway_service.go
View file @
0170d19f
...
@@ -26,7 +26,7 @@ import (
...
@@ -26,7 +26,7 @@ import (
const
(
const
(
antigravityStickySessionTTL
=
time
.
Hour
antigravityStickySessionTTL
=
time
.
Hour
antigravityDefaultMaxRetries
=
5
antigravityDefaultMaxRetries
=
3
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
)
...
@@ -52,11 +52,11 @@ type antigravityRetryLoopParams struct {
...
@@ -52,11 +52,11 @@ type antigravityRetryLoopParams struct {
action
string
action
string
body
[]
byte
body
[]
byte
quotaScope
AntigravityQuotaScope
quotaScope
AntigravityQuotaScope
maxRetries
int
c
*
gin
.
Context
c
*
gin
.
Context
httpUpstream
HTTPUpstream
httpUpstream
HTTPUpstream
settingService
*
SettingService
settingService
*
SettingService
handleError
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
handleError
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
maxRetries
int
// 可选,0 表示使用平台级默认值
}
}
// antigravityRetryLoopResult 重试循环的结果
// antigravityRetryLoopResult 重试循环的结果
...
@@ -82,9 +82,10 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
...
@@ -82,9 +82,10 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
if
len
(
availableURLs
)
==
0
{
if
len
(
availableURLs
)
==
0
{
availableURLs
=
baseURLs
availableURLs
=
baseURLs
}
}
maxRetries
:=
p
.
maxRetries
maxRetries
:=
p
.
maxRetries
if
maxRetries
<=
0
{
if
maxRetries
<=
0
{
maxRetries
=
antigravityMaxRetries
()
maxRetries
=
antigravity
Default
MaxRetries
}
}
var
resp
*
http
.
Response
var
resp
*
http
.
Response
...
@@ -161,7 +162,7 @@ urlFallbackLoop:
...
@@ -161,7 +162,7 @@ urlFallbackLoop:
continue
urlFallbackLoop
continue
urlFallbackLoop
}
}
// 账户/模型配额限流,
按最大重试次数做
指数退避
// 账户/模型配额限流,
重试 3 次(
指数退避
)
if
attempt
<
maxRetries
{
if
attempt
<
maxRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
...
@@ -1044,7 +1045,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1044,7 +1045,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
&
ForwardResult
{
return
&
ForwardResult
{
RequestID
:
requestID
,
RequestID
:
requestID
,
Usage
:
*
usage
,
Usage
:
*
usage
,
Model
:
billingModel
,
Model
:
billingModel
,
// 计费模型(可按映射模型覆盖)
Stream
:
claudeReq
.
Stream
,
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
...
@@ -1729,7 +1730,6 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
...
@@ -1729,7 +1730,6 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
}
}
return
time
.
Duration
(
seconds
)
*
time
.
Second
,
true
return
time
.
Duration
(
seconds
)
*
time
.
Second
,
true
}
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
if
statusCode
==
429
{
...
@@ -1742,9 +1742,6 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
...
@@ -1742,9 +1742,6 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
fallbackMinutes
=
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
fallbackMinutes
=
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
}
}
defaultDur
:=
time
.
Duration
(
fallbackMinutes
)
*
time
.
Minute
defaultDur
:=
time
.
Duration
(
fallbackMinutes
)
*
time
.
Minute
if
override
,
ok
:=
antigravityFallbackCooldownSeconds
();
ok
{
defaultDur
=
override
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
if
useScopeLimit
{
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
...
@@ -2185,6 +2182,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
...
@@ -2185,6 +2182,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
return
result
,
existingParts
,
setParts
return
result
,
existingParts
,
setParts
}
}
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
// 保持原始顺序,只合并连续的普通 text parts
func
mergeCollectedPartsToResponse
(
response
map
[
string
]
any
,
collectedParts
[]
map
[
string
]
any
)
map
[
string
]
any
{
if
len
(
collectedParts
)
==
0
{
return
response
}
result
,
_
,
setParts
:=
getOrCreateGeminiParts
(
response
)
// 合并策略:
// 1. 保持原始顺序
// 2. 连续的普通 text parts 合并为一个
// 3. thinking、functionCall、inlineData 等保持原样
var
mergedParts
[]
any
var
textBuffer
strings
.
Builder
flushTextBuffer
:=
func
()
{
if
textBuffer
.
Len
()
>
0
{
mergedParts
=
append
(
mergedParts
,
map
[
string
]
any
{
"text"
:
textBuffer
.
String
(),
})
textBuffer
.
Reset
()
}
}
for
_
,
part
:=
range
collectedParts
{
// 检查是否是普通 text part
if
text
,
ok
:=
part
[
"text"
]
.
(
string
);
ok
{
// 检查是否有 thought 标记
if
thought
,
_
:=
part
[
"thought"
]
.
(
bool
);
thought
{
// thinking part,先刷新 text buffer,然后保留原样
flushTextBuffer
()
mergedParts
=
append
(
mergedParts
,
part
)
}
else
{
// 普通 text,累积到 buffer
_
,
_
=
textBuffer
.
WriteString
(
text
)
}
}
else
{
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
flushTextBuffer
()
mergedParts
=
append
(
mergedParts
,
part
)
}
}
// 刷新剩余的 text
flushTextBuffer
()
setParts
(
mergedParts
)
return
result
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func
mergeImagePartsToResponse
(
response
map
[
string
]
any
,
imageParts
[]
map
[
string
]
any
)
map
[
string
]
any
{
func
mergeImagePartsToResponse
(
response
map
[
string
]
any
,
imageParts
[]
map
[
string
]
any
)
map
[
string
]
any
{
if
len
(
imageParts
)
==
0
{
if
len
(
imageParts
)
==
0
{
...
@@ -2372,8 +2421,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
...
@@ -2372,8 +2421,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
last
map
[
string
]
any
var
last
map
[
string
]
any
var
lastWithParts
map
[
string
]
any
var
lastWithParts
map
[
string
]
any
var
collectedImageParts
[]
map
[
string
]
any
// 收集所有包含图片的 parts
var
collectedParts
[]
map
[
string
]
any
// 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
var
collectedTextParts
[]
string
// 收集所有文本片段
type
scanEvent
struct
{
type
scanEvent
struct
{
line
string
line
string
...
@@ -2468,18 +2516,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
...
@@ -2468,18 +2516,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last
=
parsed
last
=
parsed
// 保留最后一个有 parts 的响应
// 保留最后一个有 parts 的响应
,并收集所有 parts
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
lastWithParts
=
parsed
lastWithParts
=
parsed
// 收集包含图片和文本的 parts
for
_
,
part
:=
range
parts
{
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
if
_
,
ok
:=
part
[
"inlineData"
]
.
(
map
[
string
]
any
);
ok
{
collectedParts
=
append
(
collectedParts
,
parts
...
)
collectedImageParts
=
append
(
collectedImageParts
,
part
)
}
if
text
,
ok
:=
part
[
"text"
]
.
(
string
);
ok
&&
text
!=
""
{
collectedTextParts
=
append
(
collectedTextParts
,
text
)
}
}
}
}
case
<-
intervalCh
:
case
<-
intervalCh
:
...
@@ -2502,32 +2544,13 @@ returnResponse:
...
@@ -2502,32 +2544,13 @@ returnResponse:
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Empty response from upstream"
)
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Empty response from upstream"
)
}
}
// 如果收集到了图片 parts,需要合并到最终响应中
// 将收集的所有 parts 合并到最终响应中
if
len
(
collectedImageParts
)
>
0
{
if
len
(
collectedParts
)
>
0
{
finalResponse
=
mergeImagePartsToResponse
(
finalResponse
,
collectedImageParts
)
finalResponse
=
mergeCollectedPartsToResponse
(
finalResponse
,
collectedParts
)
}
// 如果收集到了文本,需要合并到最终响应中
if
len
(
collectedTextParts
)
>
0
{
finalResponse
=
mergeTextPartsToResponse
(
finalResponse
,
collectedTextParts
)
}
geminiPayload
:=
finalResponse
if
_
,
ok
:=
finalResponse
[
"response"
];
!
ok
{
wrapped
:=
map
[
string
]
any
{
"response"
:
finalResponse
,
}
if
respID
,
ok
:=
finalResponse
[
"responseId"
];
ok
{
wrapped
[
"responseId"
]
=
respID
}
if
modelVersion
,
ok
:=
finalResponse
[
"modelVersion"
];
ok
{
wrapped
[
"modelVersion"
]
=
modelVersion
}
geminiPayload
=
wrapped
}
}
// 序列化为 JSON(Gemini 格式)
// 序列化为 JSON(Gemini 格式)
geminiBody
,
err
:=
json
.
Marshal
(
geminiPayload
)
geminiBody
,
err
:=
json
.
Marshal
(
finalResponse
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to marshal gemini response: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to marshal gemini response: %w"
,
err
)
}
}
...
...
backend/internal/service/antigravity_model_mapping_test.go
View file @
0170d19f
...
@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
...
@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{
"可映射 - claude-3-haiku-20240307"
,
"claude-3-haiku-20240307"
,
true
},
{
"可映射 - claude-3-haiku-20240307"
,
"claude-3-haiku-20240307"
,
true
},
// Gemini 前缀透传
// Gemini 前缀透传
{
"Gemini前缀 - gemini-
1
.5-pro"
,
"gemini-
1
.5-pro"
,
true
},
{
"Gemini前缀 - gemini-
2
.5-pro"
,
"gemini-
2
.5-pro"
,
true
},
{
"Gemini前缀 - gemini-unknown-model"
,
"gemini-unknown-model"
,
true
},
{
"Gemini前缀 - gemini-unknown-model"
,
"gemini-unknown-model"
,
true
},
{
"Gemini前缀 - gemini-future-version"
,
"gemini-future-version"
,
true
},
{
"Gemini前缀 - gemini-future-version"
,
"gemini-future-version"
,
true
},
...
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
...
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected
:
"gemini-2.5-flash"
,
expected
:
"gemini-2.5-flash"
,
},
},
{
{
name
:
"Gemini透传 - gemini-
1
.5-pro"
,
name
:
"Gemini透传 - gemini-
2
.5-pro"
,
requestedModel
:
"gemini-
1
.5-pro"
,
requestedModel
:
"gemini-
2
.5-pro"
,
accountMapping
:
nil
,
accountMapping
:
nil
,
expected
:
"gemini-
1
.5-pro"
,
expected
:
"gemini-
2
.5-pro"
,
},
},
{
{
name
:
"Gemini透传 - gemini-future-model"
,
name
:
"Gemini透传 - gemini-future-model"
,
...
...
backend/internal/service/antigravity_oauth_service.go
View file @
0170d19f
...
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
...
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result
.
Email
=
userInfo
.
Email
result
.
Email
=
userInfo
.
Email
}
}
// 获取 project_id(部分账户类型可能没有)
// 获取 project_id(部分账户类型可能没有),失败时重试
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
tokenResp
.
AccessToken
)
projectID
,
loadErr
:=
s
.
loadProjectIDWithRetry
(
ctx
,
tokenResp
.
AccessToken
,
proxyURL
,
3
)
if
err
!=
nil
{
if
loadErr
!=
nil
{
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取 project_id 失败: %v
\n
"
,
err
)
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v
\n
"
,
loadErr
)
}
else
if
loadResp
!=
nil
&&
loadResp
.
CloudAICompanionProject
!=
""
{
result
.
ProjectIDMissing
=
true
result
.
ProjectID
=
loadResp
.
CloudAICompanionProject
}
else
{
result
.
ProjectID
=
projectID
}
}
return
result
,
nil
return
result
,
nil
...
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
...
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo
.
Email
=
existingEmail
tokenInfo
.
Email
=
existingEmail
}
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id
// 每次刷新都调用 LoadCodeAssist 获取 project_id
,失败时重试
client
:=
antigravity
.
NewClient
(
proxyURL
)
existingProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
)
)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
tokenInfo
.
AccessToken
)
projectID
,
loadErr
:=
s
.
loadProjectIDWithRetry
(
ctx
,
tokenInfo
.
AccessToken
,
proxyURL
,
3
)
if
err
!=
nil
||
loadResp
==
nil
||
loadResp
.
CloudAICompanionProject
==
""
{
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
if
loadErr
!=
nil
{
existingProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"
project_id
"
))
// LoadCodeAssist 失败,保留原有
project_id
tokenInfo
.
ProjectID
=
existingProjectID
tokenInfo
.
ProjectID
=
existingProjectID
tokenInfo
.
ProjectIDMissing
=
true
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if
existingProjectID
==
""
{
tokenInfo
.
ProjectIDMissing
=
true
}
}
else
{
}
else
{
tokenInfo
.
ProjectID
=
loadResp
.
CloudAICompanionP
roject
tokenInfo
.
ProjectID
=
p
roject
ID
}
}
return
tokenInfo
,
nil
return
tokenInfo
,
nil
}
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func
(
s
*
AntigravityOAuthService
)
loadProjectIDWithRetry
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
,
maxRetries
int
)
(
string
,
error
)
{
var
lastErr
error
for
attempt
:=
0
;
attempt
<=
maxRetries
;
attempt
++
{
if
attempt
>
0
{
// 指数退避:1s, 2s, 4s
backoff
:=
time
.
Duration
(
1
<<
uint
(
attempt
-
1
))
*
time
.
Second
if
backoff
>
8
*
time
.
Second
{
backoff
=
8
*
time
.
Second
}
time
.
Sleep
(
backoff
)
}
client
:=
antigravity
.
NewClient
(
proxyURL
)
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
accessToken
)
if
err
==
nil
&&
loadResp
!=
nil
&&
loadResp
.
CloudAICompanionProject
!=
""
{
return
loadResp
.
CloudAICompanionProject
,
nil
}
// 记录错误
if
err
!=
nil
{
lastErr
=
err
}
else
if
loadResp
==
nil
{
lastErr
=
fmt
.
Errorf
(
"LoadCodeAssist 返回空响应"
)
}
else
{
lastErr
=
fmt
.
Errorf
(
"LoadCodeAssist 返回空 project_id"
)
}
}
return
""
,
fmt
.
Errorf
(
"获取 project_id 失败 (重试 %d 次后): %w"
,
maxRetries
,
lastErr
)
}
// BuildAccountCredentials 构建账户凭证
// BuildAccountCredentials 构建账户凭证
func
(
s
*
AntigravityOAuthService
)
BuildAccountCredentials
(
tokenInfo
*
AntigravityTokenInfo
)
map
[
string
]
any
{
func
(
s
*
AntigravityOAuthService
)
BuildAccountCredentials
(
tokenInfo
*
AntigravityTokenInfo
)
map
[
string
]
any
{
creds
:=
map
[
string
]
any
{
creds
:=
map
[
string
]
any
{
...
...
backend/internal/service/antigravity_rate_limit_test.go
View file @
0170d19f
...
@@ -38,6 +38,10 @@ func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, account
...
@@ -38,6 +38,10 @@ func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, account
},
nil
},
nil
}
}
func
(
s
*
stubAntigravityUpstream
)
DoWithTLS
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
enableTLSFingerprint
bool
)
(
*
http
.
Response
,
error
)
{
return
s
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
type
scopeLimitCall
struct
{
type
scopeLimitCall
struct
{
accountID
int64
accountID
int64
scope
AntigravityQuotaScope
scope
AntigravityQuotaScope
...
@@ -90,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
...
@@ -90,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
var
handleErrorCalled
bool
var
handleErrorCalled
bool
result
,
err
:=
antigravityRetryLoop
(
antigravityRetryLoopParams
{
result
,
err
:=
antigravityRetryLoop
(
antigravityRetryLoopParams
{
prefix
:
"[test]"
,
prefix
:
"[test]"
,
ctx
:
context
.
Background
(),
ctx
:
context
.
Background
(),
account
:
account
,
account
:
account
,
proxyURL
:
""
,
proxyURL
:
""
,
accessToken
:
"token"
,
accessToken
:
"token"
,
action
:
"generateContent"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
body
:
[]
byte
(
`{"input":"test"}`
),
quotaScope
:
AntigravityQuotaScopeClaude
,
quotaScope
:
AntigravityQuotaScopeClaude
,
httpUpstream
:
upstream
,
httpUpstream
:
upstream
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
handleErrorCalled
=
true
handleErrorCalled
=
true
...
...
backend/internal/service/antigravity_token_provider.go
View file @
0170d19f
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"context"
"context"
"errors"
"errors"
"log"
"log"
"log/slog"
"strconv"
"strconv"
"strings"
"strings"
"time"
"time"
...
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
...
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
}
// 3. 存入缓存
// 3. 存入缓存
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
expiresAt
!=
nil
{
if
isStale
&&
latestAccount
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
// 版本过时,使用 DB 中的最新 token
switch
{
slog
.
Debug
(
"antigravity_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
case
until
>
antigravityTokenCacheSkew
:
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
ttl
=
until
-
antigravityTokenCacheSkew
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
case
until
>
0
:
return
""
,
errors
.
New
(
"access_token not found after version check"
)
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
antigravityTokenCacheSkew
:
ttl
=
until
-
antigravityTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
}
return
accessToken
,
nil
return
accessToken
,
nil
...
...
backend/internal/service/antigravity_token_refresher.go
View file @
0170d19f
...
@@ -3,6 +3,8 @@ package service
...
@@ -3,6 +3,8 @@ package service
import
(
import
(
"context"
"context"
"fmt"
"fmt"
"log"
"strings"
"time"
"time"
)
)
...
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
...
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
}
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for
k
,
v
:=
range
account
.
Credentials
{
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
newCredentials
[
k
]
=
v
}
}
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
if
newProjectID
,
_
:=
newCredentials
[
"project_id"
]
.
(
string
);
newProjectID
==
""
{
if
oldProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
));
oldProjectID
!=
""
{
newCredentials
[
"project_id"
]
=
oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
if
tokenInfo
.
ProjectIDMissing
{
if
tokenInfo
.
ProjectIDMissing
{
return
newCredentials
,
fmt
.
Errorf
(
"missing_project_id: 账户缺少project id,可能无法使用Antigravity"
)
if
tokenInfo
.
ProjectID
!=
""
{
// 有旧的 project_id,本次获取失败,保留旧值
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id"
,
account
.
ID
)
}
else
{
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试"
,
account
.
ID
)
}
}
}
return
newCredentials
,
nil
return
newCredentials
,
nil
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
0170d19f
...
@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
...
@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s
.
authCacheL1
=
cache
s
.
authCacheL1
=
cache
}
}
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
// This should be called after the service is fully initialized.
func
(
s
*
APIKeyService
)
StartAuthCacheInvalidationSubscriber
(
ctx
context
.
Context
)
{
if
s
.
cache
==
nil
||
s
.
authCacheL1
==
nil
{
return
}
if
err
:=
s
.
cache
.
SubscribeAuthCacheInvalidation
(
ctx
,
func
(
cacheKey
string
)
{
s
.
authCacheL1
.
Del
(
cacheKey
)
});
err
!=
nil
{
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println
(
"[Service] Warning: failed to start auth cache invalidation subscriber:"
,
err
.
Error
())
}
}
func
(
s
*
APIKeyService
)
authCacheKey
(
key
string
)
string
{
func
(
s
*
APIKeyService
)
authCacheKey
(
key
string
)
string
{
sum
:=
sha256
.
Sum256
([]
byte
(
key
))
sum
:=
sha256
.
Sum256
([]
byte
(
key
))
return
hex
.
EncodeToString
(
sum
[
:
])
return
hex
.
EncodeToString
(
sum
[
:
])
...
@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
...
@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
return
return
}
}
_
=
s
.
cache
.
DeleteAuthCache
(
ctx
,
cacheKey
)
_
=
s
.
cache
.
DeleteAuthCache
(
ctx
,
cacheKey
)
// Publish invalidation message to other instances
_
=
s
.
cache
.
PublishAuthCacheInvalidation
(
ctx
,
cacheKey
)
}
}
func
(
s
*
APIKeyService
)
loadAuthCacheEntry
(
ctx
context
.
Context
,
key
,
cacheKey
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
func
(
s
*
APIKeyService
)
loadAuthCacheEntry
(
ctx
context
.
Context
,
key
,
cacheKey
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
...
...
backend/internal/service/api_key_service.go
View file @
0170d19f
...
@@ -65,6 +65,10 @@ type APIKeyCache interface {
...
@@ -65,6 +65,10 @@ type APIKeyCache interface {
GetAuthCache
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
GetAuthCache
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
SetAuthCache
(
ctx
context
.
Context
,
key
string
,
entry
*
APIKeyAuthCacheEntry
,
ttl
time
.
Duration
)
error
SetAuthCache
(
ctx
context
.
Context
,
key
string
,
entry
*
APIKeyAuthCacheEntry
,
ttl
time
.
Duration
)
error
DeleteAuthCache
(
ctx
context
.
Context
,
key
string
)
error
DeleteAuthCache
(
ctx
context
.
Context
,
key
string
)
error
// Pub/Sub for L1 cache invalidation across instances
PublishAuthCacheInvalidation
(
ctx
context
.
Context
,
cacheKey
string
)
error
SubscribeAuthCacheInvalidation
(
ctx
context
.
Context
,
handler
func
(
cacheKey
string
))
error
}
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
...
...
backend/internal/service/api_key_service_cache_test.go
View file @
0170d19f
...
@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
...
@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
return
nil
return
nil
}
}
func
(
s
*
authCacheStub
)
PublishAuthCacheInvalidation
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
return
nil
}
func
(
s
*
authCacheStub
)
SubscribeAuthCacheInvalidation
(
ctx
context
.
Context
,
handler
func
(
cacheKey
string
))
error
{
return
nil
}
func
TestAPIKeyService_GetByKey_UsesL2Cache
(
t
*
testing
.
T
)
{
func
TestAPIKeyService_GetByKey_UsesL2Cache
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
repo
:=
&
authRepoStub
{
...
...
backend/internal/service/api_key_service_delete_test.go
View file @
0170d19f
...
@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
...
@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
return
nil
return
nil
}
}
func
(
s
*
apiKeyCacheStub
)
PublishAuthCacheInvalidation
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
return
nil
}
func
(
s
*
apiKeyCacheStub
)
SubscribeAuthCacheInvalidation
(
ctx
context
.
Context
,
handler
func
(
cacheKey
string
))
error
{
return
nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 1
// - GetKeyAndOwnerID 返回所有者 ID 为 1
...
...
backend/internal/service/auth_service.go
View file @
0170d19f
...
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
// 应用优惠码(如果提供)
// 应用优惠码(如果提供
且功能已启用
)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
{
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
// 优惠码应用失败不影响注册,只记录日志
// 优惠码应用失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
log
.
Printf
(
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
...
@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
...
@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
// 生成新token
return
s
.
GenerateToken
(
user
)
return
s
.
GenerateToken
(
user
)
}
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func
(
s
*
AuthService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
false
}
// Must have email verification enabled and SMTP configured
if
!
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
return
s
.
settingService
.
IsPasswordResetEnabled
(
ctx
)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func
(
s
*
AuthService
)
preparePasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
(
string
,
string
,
bool
)
{
// Check if user exists (but don't reveal this to the caller)
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// Security: Log but don't reveal that user doesn't exist
log
.
Printf
(
"[Auth] Password reset requested for non-existent email: %s"
,
email
)
return
""
,
""
,
false
}
log
.
Printf
(
"[Auth] Database error checking email for password reset: %v"
,
err
)
return
""
,
""
,
false
}
// Check if user is active
if
!
user
.
IsActive
()
{
log
.
Printf
(
"[Auth] Password reset requested for inactive user: %s"
,
email
)
return
""
,
""
,
false
}
// Get site name
siteName
:=
"Sub2API"
if
s
.
settingService
!=
nil
{
siteName
=
s
.
settingService
.
GetSiteName
(
ctx
)
}
// Build reset URL base
resetURL
:=
fmt
.
Sprintf
(
"%s/reset-password"
,
strings
.
TrimSuffix
(
frontendBaseURL
,
"/"
))
return
siteName
,
resetURL
,
true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailService
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to send password reset email to %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email sent to: %s"
,
email
)
return
nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordResetAsync
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailQueueService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailQueueService
.
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to enqueue password reset email for %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email enqueued for: %s"
,
email
)
return
nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func
(
s
*
AuthService
)
ResetPassword
(
ctx
context
.
Context
,
email
,
token
,
newPassword
string
)
error
{
// Check if password reset is enabled
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if
err
:=
s
.
emailService
.
ConsumePasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Get user
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
return
ErrInvalidResetToken
// Token was valid but user was deleted
}
log
.
Printf
(
"[Auth] Database error getting user for password reset: %v"
,
err
)
return
ErrServiceUnavailable
}
// Check if user is active
if
!
user
.
IsActive
()
{
return
ErrUserNotActive
}
// Hash new password
hashedPassword
,
err
:=
s
.
HashPassword
(
newPassword
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
// Update password and increment TokenVersion
user
.
PasswordHash
=
hashedPassword
user
.
TokenVersion
++
// Invalidate all existing tokens
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error updating password for user %d: %v"
,
user
.
ID
,
err
)
return
ErrServiceUnavailable
}
log
.
Printf
(
"[Auth] Password reset successful for user: %s"
,
email
)
return
nil
}
backend/internal/service/auth_service_register_test.go
View file @
0170d19f
...
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
...
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return
nil
return
nil
}
}
func
(
s
*
emailCacheStub
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailCacheStub
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
{
return
false
}
func
(
s
*
emailCacheStub
)
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
cfg
:=
&
config
.
Config
{
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
JWT
:
config
.
JWTConfig
{
...
...
backend/internal/service/claude_token_provider.go
View file @
0170d19f
...
@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
}
// 3. 存入缓存
// 3. 存入缓存
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
refreshFailed
{
if
isStale
&&
latestAccount
!=
nil
{
//
刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
//
版本过时,使用 DB 中的最新 token
ttl
=
time
.
Minute
slog
.
Debug
(
"claude_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
slog
.
Debug
(
"claude_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed
"
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token
"
)
}
else
if
expiresAt
!=
nil
{
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
until
:=
time
.
Until
(
*
expiresAt
)
return
""
,
errors
.
New
(
"access_token not found after version check"
)
switch
{
}
case
until
>
claudeTokenCacheSkew
:
// 不写入缓存,让下次请求重新处理
ttl
=
until
-
claudeTokenCacheSkew
}
else
{
case
until
>
0
:
ttl
:=
30
*
time
.
Minute
ttl
=
until
if
refreshFailed
{
default
:
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
ttl
=
time
.
Minute
slog
.
Debug
(
"claude_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
claudeTokenCacheSkew
:
ttl
=
until
-
claudeTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
...
...
backend/internal/service/dashboard_aggregation_service.go
View file @
0170d19f
...
@@ -20,12 +20,16 @@ var (
...
@@ -20,12 +20,16 @@ var (
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
ErrDashboardBackfillDisabled
=
errors
.
New
(
"仪表盘聚合回填已禁用"
)
ErrDashboardBackfillDisabled
=
errors
.
New
(
"仪表盘聚合回填已禁用"
)
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge
=
errors
.
New
(
"回填时间跨度过大"
)
ErrDashboardBackfillTooLarge
=
errors
.
New
(
"回填时间跨度过大"
)
errDashboardAggregationRunning
=
errors
.
New
(
"聚合作业正在运行"
)
)
)
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type
DashboardAggregationRepository
interface
{
type
DashboardAggregationRepository
interface
{
AggregateRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
AggregateRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
UpdateAggregationWatermark
(
ctx
context
.
Context
,
aggregatedAt
time
.
Time
)
error
UpdateAggregationWatermark
(
ctx
context
.
Context
,
aggregatedAt
time
.
Time
)
error
CleanupAggregates
(
ctx
context
.
Context
,
hourlyCutoff
,
dailyCutoff
time
.
Time
)
error
CleanupAggregates
(
ctx
context
.
Context
,
hourlyCutoff
,
dailyCutoff
time
.
Time
)
error
...
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
...
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return
nil
return
nil
}
}
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
// 与 TriggerBackfill 不同:
// - 不依赖 backfill_enabled(这是内部一致性修复)
// - 不更新 watermark(避免影响正常增量聚合游标)
func
(
s
*
DashboardAggregationService
)
TriggerRecomputeRange
(
start
,
end
time
.
Time
)
error
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
errors
.
New
(
"聚合服务未初始化"
)
}
if
!
s
.
cfg
.
Enabled
{
return
errors
.
New
(
"聚合服务已禁用"
)
}
if
!
end
.
After
(
start
)
{
return
errors
.
New
(
"重新计算时间范围无效"
)
}
go
func
()
{
const
maxRetries
=
3
for
i
:=
0
;
i
<
maxRetries
;
i
++
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
err
:=
s
.
recomputeRange
(
ctx
,
start
,
end
)
cancel
()
if
err
==
nil
{
return
}
if
!
errors
.
Is
(
err
,
errDashboardAggregationRunning
)
{
log
.
Printf
(
"[DashboardAggregation] 重新计算失败: %v"
,
err
)
return
}
time
.
Sleep
(
5
*
time
.
Second
)
}
log
.
Printf
(
"[DashboardAggregation] 重新计算放弃: 聚合作业持续占用"
)
}()
return
nil
}
func
(
s
*
DashboardAggregationService
)
recomputeRecentDays
()
{
func
(
s
*
DashboardAggregationService
)
recomputeRecentDays
()
{
days
:=
s
.
cfg
.
RecomputeDays
days
:=
s
.
cfg
.
RecomputeDays
if
days
<=
0
{
if
days
<=
0
{
...
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
...
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
}
}
}
}
func
(
s
*
DashboardAggregationService
)
recomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
return
errDashboardAggregationRunning
}
defer
atomic
.
StoreInt32
(
&
s
.
running
,
0
)
jobStart
:=
time
.
Now
()
.
UTC
()
if
err
:=
s
.
repo
.
RecomputeRange
(
ctx
,
start
,
end
);
err
!=
nil
{
return
err
}
log
.
Printf
(
"[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)"
,
start
.
UTC
()
.
Format
(
time
.
RFC3339
),
end
.
UTC
()
.
Format
(
time
.
RFC3339
),
time
.
Since
(
jobStart
)
.
String
(),
)
return
nil
}
func
(
s
*
DashboardAggregationService
)
runScheduledAggregation
()
{
func
(
s
*
DashboardAggregationService
)
runScheduledAggregation
()
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
return
return
...
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
...
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
func
(
s
*
DashboardAggregationService
)
backfillRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
func
(
s
*
DashboardAggregationService
)
backfillRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
return
err
ors
.
New
(
"聚合作业正在运行"
)
return
err
DashboardAggregationRunning
}
}
defer
atomic
.
StoreInt32
(
&
s
.
running
,
0
)
defer
atomic
.
StoreInt32
(
&
s
.
running
,
0
)
...
...
backend/internal/service/dashboard_aggregation_service_test.go
View file @
0170d19f
...
@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
...
@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
return
s
.
aggregateErr
return
s
.
aggregateErr
}
}
func
(
s
*
dashboardAggregationRepoTestStub
)
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
return
s
.
AggregateRange
(
ctx
,
start
,
end
)
}
func
(
s
*
dashboardAggregationRepoTestStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
func
(
s
*
dashboardAggregationRepoTestStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
return
s
.
watermark
,
nil
return
s
.
watermark
,
nil
}
}
...
...
backend/internal/service/dashboard_service.go
View file @
0170d19f
...
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
...
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return
stats
,
nil
return
stats
,
nil
}
}
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
)
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
,
billingType
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get usage trend with filters: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get usage trend with filters: %w"
,
err
)
}
}
return
trend
,
nil
return
trend
,
nil
}
}
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
,
billingType
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get model stats with filters: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get model stats with filters: %w"
,
err
)
}
}
...
...
backend/internal/service/dashboard_service_test.go
View file @
0170d19f
...
@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
...
@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return
nil
return
nil
}
}
func
(
s
*
dashboardAggregationRepoStub
)
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardAggregationRepoStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
func
(
s
*
dashboardAggregationRepoStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
if
s
.
err
!=
nil
{
if
s
.
err
!=
nil
{
return
time
.
Time
{},
s
.
err
return
time
.
Time
{},
s
.
err
...
...
backend/internal/service/domain_constants.go
View file @
0170d19f
package
service
package
service
import
"github.com/Wei-Shaw/sub2api/internal/domain"
// Status constants
// Status constants
const
(
const
(
StatusActive
=
"a
ctive
"
StatusActive
=
domain
.
StatusA
ctive
StatusDisabled
=
"
disabled
"
StatusDisabled
=
d
omain
.
StatusD
isabled
StatusError
=
"e
rror
"
StatusError
=
domain
.
StatusE
rror
StatusUnused
=
"u
nused
"
StatusUnused
=
domain
.
StatusU
nused
StatusUsed
=
"u
sed
"
StatusUsed
=
domain
.
StatusU
sed
StatusExpired
=
"e
xpired
"
StatusExpired
=
domain
.
StatusE
xpired
)
)
// Role constants
// Role constants
const
(
const
(
RoleAdmin
=
"a
dmin
"
RoleAdmin
=
domain
.
RoleA
dmin
RoleUser
=
"u
ser
"
RoleUser
=
domain
.
RoleU
ser
)
)
// Platform constants
// Platform constants
const
(
const
(
PlatformAnthropic
=
"a
nthropic
"
PlatformAnthropic
=
domain
.
PlatformA
nthropic
PlatformOpenAI
=
"openai"
PlatformOpenAI
=
domain
.
PlatformOpenAI
PlatformGemini
=
"g
emini
"
PlatformGemini
=
domain
.
PlatformG
emini
PlatformAntigravity
=
"a
ntigravity
"
PlatformAntigravity
=
domain
.
PlatformA
ntigravity
)
)
// Account type constants
// Account type constants
const
(
const
(
AccountTypeOAuth
=
"oa
uth
"
// OAuth类型账号(full scope: profile + inference)
AccountTypeOAuth
=
domain
.
AccountTypeOA
uth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
"s
etup
-t
oken
"
// Setup Token类型账号(inference only scope)
AccountTypeSetupToken
=
domain
.
AccountTypeS
etup
T
oken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
"apikey"
// API Key类型账号
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
)
)
// Redeem type constants
// Redeem type constants
const
(
const
(
RedeemTypeBalance
=
"b
alance
"
RedeemTypeBalance
=
domain
.
RedeemTypeB
alance
RedeemTypeConcurrency
=
"c
oncurrency
"
RedeemTypeConcurrency
=
domain
.
RedeemTypeC
oncurrency
RedeemTypeSubscription
=
"s
ubscription
"
RedeemTypeSubscription
=
domain
.
RedeemTypeS
ubscription
)
)
// PromoCode status constants
// PromoCode status constants
const
(
const
(
PromoCodeStatusActive
=
"a
ctive
"
PromoCodeStatusActive
=
domain
.
PromoCodeStatusA
ctive
PromoCodeStatusDisabled
=
"
disabled
"
PromoCodeStatusDisabled
=
d
omain
.
PromoCodeStatusD
isabled
)
)
// Admin adjustment type constants
// Admin adjustment type constants
const
(
const
(
AdjustmentTypeAdminBalance
=
"a
dmin
_b
alance
"
// 管理员调整余额
AdjustmentTypeAdminBalance
=
domain
.
AdjustmentTypeA
dmin
B
alance
// 管理员调整余额
AdjustmentTypeAdminConcurrency
=
"a
dmin
_c
oncurrency
"
// 管理员调整并发数
AdjustmentTypeAdminConcurrency
=
domain
.
AdjustmentTypeA
dmin
C
oncurrency
// 管理员调整并发数
)
)
// Group subscription type constants
// Group subscription type constants
const
(
const
(
SubscriptionTypeStandard
=
"s
tandard
"
// 标准计费模式(按余额扣费)
SubscriptionTypeStandard
=
domain
.
SubscriptionTypeS
tandard
// 标准计费模式(按余额扣费)
SubscriptionTypeSubscription
=
"s
ubscription
"
// 订阅模式(按限额控制)
SubscriptionTypeSubscription
=
domain
.
SubscriptionTypeS
ubscription
// 订阅模式(按限额控制)
)
)
// Subscription status constants
// Subscription status constants
const
(
const
(
SubscriptionStatusActive
=
"a
ctive
"
SubscriptionStatusActive
=
domain
.
SubscriptionStatusA
ctive
SubscriptionStatusExpired
=
"e
xpired
"
SubscriptionStatusExpired
=
domain
.
SubscriptionStatusE
xpired
SubscriptionStatusSuspended
=
"s
uspended
"
SubscriptionStatusSuspended
=
domain
.
SubscriptionStatusS
uspended
)
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
...
@@ -69,8 +71,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
...
@@ -69,8 +71,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
// Setting keys
const
(
const
(
// 注册设置
// 注册设置
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
// 邮件服务设置
// 邮件服务设置
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
...
@@ -86,6 +90,9 @@ const (
...
@@ -86,6 +90,9 @@ const (
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled
=
"totp_enabled"
// 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
...
@@ -93,13 +100,16 @@ const (
...
@@ -93,13 +100,16 @@ const (
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
// OEM设置
// OEM设置
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeyAPIBaseURL
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyAPIBaseURL
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled
=
"purchase_subscription_enabled"
// 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL
=
"purchase_subscription_url"
// “购买订阅”页面 URL(作为 iframe src)
// 默认配置
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
...
...
backend/internal/service/email_queue_service.go
View file @
0170d19f
...
@@ -8,11 +8,18 @@ import (
...
@@ -8,11 +8,18 @@ import (
"time"
"time"
)
)
// Task type constants
const
(
TaskTypeVerifyCode
=
"verify_code"
TaskTypePasswordReset
=
"password_reset"
)
// EmailTask 邮件发送任务
// EmailTask 邮件发送任务
type
EmailTask
struct
{
type
EmailTask
struct
{
Email
string
Email
string
SiteName
string
SiteName
string
TaskType
string
// "verify_code"
TaskType
string
// "verify_code" or "password_reset"
ResetURL
string
// Only used for password_reset task type
}
}
// EmailQueueService 异步邮件队列服务
// EmailQueueService 异步邮件队列服务
...
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
...
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer
cancel
()
defer
cancel
()
switch
task
.
TaskType
{
switch
task
.
TaskType
{
case
"v
erify
_c
ode
"
:
case
TaskTypeV
erify
C
ode
:
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
}
}
case
TaskTypePasswordReset
:
if
err
:=
s
.
emailService
.
SendPasswordResetEmailWithCooldown
(
ctx
,
task
.
Email
,
task
.
SiteName
,
task
.
ResetURL
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send password reset to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent password reset to %s"
,
workerID
,
task
.
Email
)
}
default
:
default
:
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
}
}
...
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
...
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task
:=
EmailTask
{
task
:=
EmailTask
{
Email
:
email
,
Email
:
email
,
SiteName
:
siteName
,
SiteName
:
siteName
,
TaskType
:
"v
erify
_c
ode
"
,
TaskType
:
TaskTypeV
erify
C
ode
,
}
}
select
{
select
{
...
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
...
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func
(
s
*
EmailQueueService
)
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
string
)
error
{
task
:=
EmailTask
{
Email
:
email
,
SiteName
:
siteName
,
TaskType
:
TaskTypePasswordReset
,
ResetURL
:
resetURL
,
}
select
{
case
s
.
taskChan
<-
task
:
log
.
Printf
(
"[EmailQueue] Enqueued password reset task for %s"
,
email
)
return
nil
default
:
return
fmt
.
Errorf
(
"email queue is full"
)
}
}
// Stop 停止队列服务
// Stop 停止队列服务
func
(
s
*
EmailQueueService
)
Stop
()
{
func
(
s
*
EmailQueueService
)
Stop
()
{
close
(
s
.
stopChan
)
close
(
s
.
stopChan
)
...
...
backend/internal/service/email_service.go
View file @
0170d19f
...
@@ -3,11 +3,14 @@ package service
...
@@ -3,11 +3,14 @@ package service
import
(
import
(
"context"
"context"
"crypto/rand"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"crypto/tls"
"encoding/hex"
"fmt"
"fmt"
"log"
"log"
"math/big"
"math/big"
"net/smtp"
"net/smtp"
"net/url"
"strconv"
"strconv"
"time"
"time"
...
@@ -19,6 +22,9 @@ var (
...
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
// Password reset errors
ErrInvalidResetToken
=
infraerrors
.
BadRequest
(
"INVALID_RESET_TOKEN"
,
"invalid or expired password reset token"
)
)
)
// EmailCache defines cache operations for email service
// EmailCache defines cache operations for email service
...
@@ -26,6 +32,16 @@ type EmailCache interface {
...
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
}
}
// VerificationCodeData represents verification code data
// VerificationCodeData represents verification code data
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt
time
.
Time
CreatedAt
time
.
Time
}
}
// PasswordResetTokenData represents password reset token data
type
PasswordResetTokenData
struct
{
Token
string
CreatedAt
time
.
Time
}
const
(
const
(
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
maxVerifyCodeAttempts
=
5
maxVerifyCodeAttempts
=
5
// Password reset token settings
passwordResetTokenTTL
=
30
*
time
.
Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown
=
30
*
time
.
Second
)
)
// SMTPConfig SMTP配置
// SMTPConfig SMTP配置
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return
ErrVerifyCodeMaxAttempts
return
ErrVerifyCodeMaxAttempts
}
}
// 验证码不匹配
// 验证码不匹配
(constant-time comparison to prevent timing attacks)
if
data
.
Code
!=
code
{
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return
client
.
Quit
()
return
client
.
Quit
()
}
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func
(
s
*
EmailService
)
GeneratePasswordResetToken
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func
(
s
*
EmailService
)
SendPasswordResetEmail
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
var
token
string
var
needSaveToken
bool
// Check if token already exists
existing
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
==
nil
&&
existing
!=
nil
{
// Token exists, reuse it (allows resending email without generating new token)
token
=
existing
.
Token
needSaveToken
=
false
}
else
{
// Generate new token
token
,
err
=
s
.
GeneratePasswordResetToken
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
needSaveToken
=
true
}
// Save token to Redis (only if new token generated)
if
needSaveToken
{
data
:=
&
PasswordResetTokenData
{
Token
:
token
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetPasswordResetToken
(
ctx
,
email
,
data
,
passwordResetTokenTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save reset token: %w"
,
err
)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL
:=
fmt
.
Sprintf
(
"%s?email=%s&token=%s"
,
resetURL
,
url
.
QueryEscape
(
email
),
url
.
QueryEscape
(
token
))
// Build email content
subject
:=
fmt
.
Sprintf
(
"[%s] 密码重置请求"
,
siteName
)
body
:=
s
.
buildPasswordResetEmailBody
(
fullResetURL
,
siteName
)
// Send email
if
err
:=
s
.
SendEmail
(
ctx
,
email
,
subject
,
body
);
err
!=
nil
{
return
fmt
.
Errorf
(
"send email: %w"
,
err
)
}
return
nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] Password reset email skipped (cooldown): %s"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if
err
:=
s
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
return
err
}
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to set password reset cooldown for %s: %v"
,
email
,
err
)
}
return
nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func
(
s
*
EmailService
)
VerifyPasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
data
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Token
),
[]
byte
(
token
))
!=
1
{
return
ErrInvalidResetToken
}
return
nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func
(
s
*
EmailService
)
ConsumePasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
// Verify first
if
err
:=
s
.
VerifyPasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete password reset token after consumption: %v"
,
err
)
}
return
nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func
(
s
*
EmailService
)
buildPasswordResetEmailBody
(
resetURL
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`
,
siteName
,
resetURL
,
resetURL
)
}
Prev
1
…
5
6
7
8
9
10
11
12
13
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment