Commit 0170d19f authored by song's avatar song
Browse files

merge upstream main

parent 7ade9baa
...@@ -26,7 +26,7 @@ import ( ...@@ -26,7 +26,7 @@ import (
const ( const (
antigravityStickySessionTTL = time.Hour antigravityStickySessionTTL = time.Hour
antigravityDefaultMaxRetries = 5 antigravityDefaultMaxRetries = 3
antigravityRetryBaseDelay = 1 * time.Second antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
...@@ -52,11 +52,11 @@ type antigravityRetryLoopParams struct { ...@@ -52,11 +52,11 @@ type antigravityRetryLoopParams struct {
action string action string
body []byte body []byte
quotaScope AntigravityQuotaScope quotaScope AntigravityQuotaScope
maxRetries int
c *gin.Context c *gin.Context
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
settingService *SettingService settingService *SettingService
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
maxRetries int // 可选,0 表示使用平台级默认值
} }
// antigravityRetryLoopResult 重试循环的结果 // antigravityRetryLoopResult 重试循环的结果
...@@ -82,9 +82,10 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe ...@@ -82,9 +82,10 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
if len(availableURLs) == 0 { if len(availableURLs) == 0 {
availableURLs = baseURLs availableURLs = baseURLs
} }
maxRetries := p.maxRetries maxRetries := p.maxRetries
if maxRetries <= 0 { if maxRetries <= 0 {
maxRetries = antigravityMaxRetries() maxRetries = antigravityDefaultMaxRetries
} }
var resp *http.Response var resp *http.Response
...@@ -161,7 +162,7 @@ urlFallbackLoop: ...@@ -161,7 +162,7 @@ urlFallbackLoop:
continue urlFallbackLoop continue urlFallbackLoop
} }
// 账户/模型配额限流,按最大重试次数做指数退避 // 账户/模型配额限流,重试 3 次(指数退避
if attempt < maxRetries { if attempt < maxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
...@@ -1044,7 +1045,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1044,7 +1045,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: billingModel, Model: billingModel, // 计费模型(可按映射模型覆盖)
Stream: claudeReq.Stream, Stream: claudeReq.Stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -1729,7 +1730,6 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) { ...@@ -1729,7 +1730,6 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
} }
return time.Duration(seconds) * time.Second, true return time.Duration(seconds) * time.Second, true
} }
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间) // 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 { if statusCode == 429 {
...@@ -1742,9 +1742,6 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre ...@@ -1742,9 +1742,6 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
} }
defaultDur := time.Duration(fallbackMinutes) * time.Minute defaultDur := time.Duration(fallbackMinutes) * time.Minute
if override, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = override
}
ra := time.Now().Add(defaultDur) ra := time.Now().Add(defaultDur)
if useScopeLimit { if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
...@@ -2185,6 +2182,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi ...@@ -2185,6 +2182,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
return result, existingParts, setParts return result, existingParts, setParts
} }
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
// 保持原始顺序,只合并连续的普通 text parts
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
if len(collectedParts) == 0 {
return response
}
result, _, setParts := getOrCreateGeminiParts(response)
// 合并策略:
// 1. 保持原始顺序
// 2. 连续的普通 text parts 合并为一个
// 3. thinking、functionCall、inlineData 等保持原样
var mergedParts []any
var textBuffer strings.Builder
flushTextBuffer := func() {
if textBuffer.Len() > 0 {
mergedParts = append(mergedParts, map[string]any{
"text": textBuffer.String(),
})
textBuffer.Reset()
}
}
for _, part := range collectedParts {
// 检查是否是普通 text part
if text, ok := part["text"].(string); ok {
// 检查是否有 thought 标记
if thought, _ := part["thought"].(bool); thought {
// thinking part,先刷新 text buffer,然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
} else {
// 普通 text,累积到 buffer
_, _ = textBuffer.WriteString(text)
}
} else {
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
}
}
// 刷新剩余的 text
flushTextBuffer()
setParts(mergedParts)
return result
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中 // mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any { func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
if len(imageParts) == 0 { if len(imageParts) == 0 {
...@@ -2372,8 +2421,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont ...@@ -2372,8 +2421,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int var firstTokenMs *int
var last map[string]any var last map[string]any
var lastWithParts map[string]any var lastWithParts map[string]any
var collectedImageParts []map[string]any // 收集所有包含图片的 parts var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
var collectedTextParts []string // 收集所有文本片段
type scanEvent struct { type scanEvent struct {
line string line string
...@@ -2468,18 +2516,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont ...@@ -2468,18 +2516,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last = parsed last = parsed
// 保留最后一个有 parts 的响应 // 保留最后一个有 parts 的响应,并收集所有 parts
if parts := extractGeminiParts(parsed); len(parts) > 0 { if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed lastWithParts = parsed
// 收集包含图片和文本的 parts
for _, part := range parts { // 收集所有 parts(text、thinking、functionCall、inlineData 等)
if _, ok := part["inlineData"].(map[string]any); ok { collectedParts = append(collectedParts, parts...)
collectedImageParts = append(collectedImageParts, part)
}
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
} }
case <-intervalCh: case <-intervalCh:
...@@ -2502,32 +2544,13 @@ returnResponse: ...@@ -2502,32 +2544,13 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream") return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
} }
// 如果收集到了图片 parts,需要合并到最终响应中 // 将收集的所有 parts 合并到最终响应中
if len(collectedImageParts) > 0 { if len(collectedParts) > 0 {
finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts) finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
}
// 如果收集到了文本,需要合并到最终响应中
if len(collectedTextParts) > 0 {
finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
}
geminiPayload := finalResponse
if _, ok := finalResponse["response"]; !ok {
wrapped := map[string]any{
"response": finalResponse,
}
if respID, ok := finalResponse["responseId"]; ok {
wrapped["responseId"] = respID
}
if modelVersion, ok := finalResponse["modelVersion"]; ok {
wrapped["modelVersion"] = modelVersion
}
geminiPayload = wrapped
} }
// 序列化为 JSON(Gemini 格式) // 序列化为 JSON(Gemini 格式)
geminiBody, err := json.Marshal(geminiPayload) geminiBody, err := json.Marshal(finalResponse)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal gemini response: %w", err) return nil, fmt.Errorf("failed to marshal gemini response: %w", err)
} }
......
...@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) { ...@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true}, {"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传 // Gemini 前缀透传
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true}, {"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true}, {"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true}, {"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
...@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "gemini-2.5-flash", expected: "gemini-2.5-flash",
}, },
{ {
name: "Gemini透传 - gemini-1.5-pro", name: "Gemini透传 - gemini-2.5-pro",
requestedModel: "gemini-1.5-pro", requestedModel: "gemini-2.5-pro",
accountMapping: nil, accountMapping: nil,
expected: "gemini-1.5-pro", expected: "gemini-2.5-pro",
}, },
{ {
name: "Gemini透传 - gemini-future-model", name: "Gemini透传 - gemini-future-model",
......
...@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig ...@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email result.Email = userInfo.Email
} }
// 获取 project_id(部分账户类型可能没有) // 获取 project_id(部分账户类型可能没有),失败时重试
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken) projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if err != nil { if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err) fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" { result.ProjectIDMissing = true
result.ProjectID = loadResp.CloudAICompanionProject } else {
result.ProjectID = projectID
} }
return result, nil return result, nil
...@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou ...@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail tokenInfo.Email = existingEmail
} }
// 每次刷新都调用 LoadCodeAssist 获取 project_id // 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
client := antigravity.NewClient(proxyURL) existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken) projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失 if loadErr != nil {
existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) // LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true // 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else { } else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject tokenInfo.ProjectID = projectID
} }
return tokenInfo, nil return tokenInfo, nil
} }
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// 指数退避:1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second {
backoff = 8 * time.Second
}
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
// 记录错误
if err != nil {
lastErr = err
} else if loadResp == nil {
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
} else {
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
}
}
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
// BuildAccountCredentials 构建账户凭证 // BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{ creds := map[string]any{
......
...@@ -38,6 +38,10 @@ func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, account ...@@ -38,6 +38,10 @@ func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, account
}, nil }, nil
} }
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
type scopeLimitCall struct { type scopeLimitCall struct {
accountID int64 accountID int64
scope AntigravityQuotaScope scope AntigravityQuotaScope
...@@ -90,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { ...@@ -90,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
var handleErrorCalled bool var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{ result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]", prefix: "[test]",
ctx: context.Background(), ctx: context.Background(),
account: account, account: account,
proxyURL: "", proxyURL: "",
accessToken: "token", accessToken: "token",
action: "generateContent", action: "generateContent",
body: []byte(`{"input":"test"}`), body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude, quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream, httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
handleErrorCalled = true handleErrorCalled = true
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"log" "log"
"log/slog"
"strconv" "strconv"
"strings" "strings"
"time" "time"
...@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials") return "", errors.New("access_token not found in credentials")
} }
// 3. 存入缓存 // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil { if p.tokenCache != nil {
ttl := 30 * time.Minute latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if expiresAt != nil { if isStale && latestAccount != nil {
until := time.Until(*expiresAt) // 版本过时,使用 DB 中的最新 token
switch { slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
case until > antigravityTokenCacheSkew: accessToken = latestAccount.GetCredential("access_token")
ttl = until - antigravityTokenCacheSkew if strings.TrimSpace(accessToken) == "" {
case until > 0: return "", errors.New("access_token not found after version check")
ttl = until
default:
ttl = time.Minute
} }
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
} }
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
} }
return accessToken, nil return accessToken, nil
......
...@@ -3,6 +3,8 @@ package service ...@@ -3,6 +3,8 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"strings"
"time" "time"
) )
...@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun ...@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
} }
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo) newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for k, v := range account.Credentials { for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists { if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v newCredentials[k] = v
} }
} }
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记 // 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
if tokenInfo.ProjectIDMissing { if tokenInfo.ProjectIDMissing {
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity") if tokenInfo.ProjectID != "" {
// 有旧的 project_id,本次获取失败,保留旧值
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
} else {
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
}
} }
return newCredentials, nil return newCredentials, nil
......
...@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) { ...@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCacheL1 = cache s.authCacheL1 = cache
} }
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
// This should be called after the service is fully initialized.
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
if s.cache == nil || s.authCacheL1 == nil {
return
}
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
s.authCacheL1.Del(cacheKey)
}); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
}
}
func (s *APIKeyService) authCacheKey(key string) string { func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key)) sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
...@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { ...@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
return return
} }
_ = s.cache.DeleteAuthCache(ctx, cacheKey) _ = s.cache.DeleteAuthCache(ctx, cacheKey)
// Publish invalidation message to other instances
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
} }
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
......
...@@ -65,6 +65,10 @@ type APIKeyCache interface { ...@@ -65,6 +65,10 @@ type APIKeyCache interface {
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error DeleteAuthCache(ctx context.Context, key string) error
// Pub/Sub for L1 cache invalidation across instances
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
} }
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力 // APIKeyAuthCacheInvalidator 提供认证缓存失效能力
......
...@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { ...@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
return nil return nil
} }
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{} cache := &authCacheStub{}
repo := &authRepoStub{ repo := &authRepoStub{
......
...@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error ...@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
return nil return nil
} }
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为: // 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 1 // - GetKeyAndOwnerID 返回所有者 ID 为 1
......
...@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
// 应用优惠码(如果提供) // 应用优惠码(如果提供且功能已启用
if promoCode != "" && s.promoService != nil { if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志 // 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err) log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
...@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( ...@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token // 生成新token
return s.GenerateToken(user) return s.GenerateToken(user)
} }
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
// Must have email verification enabled and SMTP configured
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return false
}
return s.settingService.IsPasswordResetEnabled(ctx)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
// Check if user exists (but don't reveal this to the caller)
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// Security: Log but don't reveal that user doesn't exist
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
return "", "", false
}
log.Printf("[Auth] Database error checking email for password reset: %v", err)
return "", "", false
}
// Check if user is active
if !user.IsActive() {
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
return "", "", false
}
// Get site name
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// Build reset URL base
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
return siteName, resetURL, true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email sent to: %s", email)
return nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailQueueService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email enqueued for: %s", email)
return nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
// Check if password reset is enabled
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
return err
}
// Get user
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrInvalidResetToken // Token was valid but user was deleted
}
log.Printf("[Auth] Database error getting user for password reset: %v", err)
return ErrServiceUnavailable
}
// Check if user is active
if !user.IsActive() {
return ErrUserNotActive
}
// Hash new password
hashedPassword, err := s.HashPassword(newPassword)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
// Update password and increment TokenVersion
user.PasswordHash = hashedPassword
user.TokenVersion++ // Invalidate all existing tokens
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
return ErrServiceUnavailable
}
log.Printf("[Auth] Password reset successful for user: %s", email)
return nil
}
...@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin ...@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil return nil
} }
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
return false
}
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
return nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{ cfg := &config.Config{
JWT: config.JWTConfig{ JWT: config.JWTConfig{
......
...@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials") return "", errors.New("access_token not found in credentials")
} }
// 3. 存入缓存 // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil { if p.tokenCache != nil {
ttl := 30 * time.Minute latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if refreshFailed { if isStale && latestAccount != nil {
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 // 版本过时,使用 DB 中的最新 token
ttl = time.Minute slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") accessToken = latestAccount.GetCredential("access_token")
} else if expiresAt != nil { if strings.TrimSpace(accessToken) == "" {
until := time.Until(*expiresAt) return "", errors.New("access_token not found after version check")
switch { }
case until > claudeTokenCacheSkew: // 不写入缓存,让下次请求重新处理
ttl = until - claudeTokenCacheSkew } else {
case until > 0: ttl := 30 * time.Minute
ttl = until if refreshFailed {
default: // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
} }
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
} }
} }
......
...@@ -20,12 +20,16 @@ var ( ...@@ -20,12 +20,16 @@ var (
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。 // ErrDashboardBackfillDisabled 当配置禁用回填时返回。
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。 // ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
) )
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。 // DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type DashboardAggregationRepository interface { type DashboardAggregationRepository interface {
AggregateRange(ctx context.Context, start, end time.Time) error AggregateRange(ctx context.Context, start, end time.Time) error
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
RecomputeRange(ctx context.Context, start, end time.Time) error
GetAggregationWatermark(ctx context.Context) (time.Time, error) GetAggregationWatermark(ctx context.Context) (time.Time, error)
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
...@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro ...@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return nil return nil
} }
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
// 与 TriggerBackfill 不同:
// - 不依赖 backfill_enabled(这是内部一致性修复)
// - 不更新 watermark(避免影响正常增量聚合游标)
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
if s == nil || s.repo == nil {
return errors.New("聚合服务未初始化")
}
if !s.cfg.Enabled {
return errors.New("聚合服务已禁用")
}
if !end.After(start) {
return errors.New("重新计算时间范围无效")
}
go func() {
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
err := s.recomputeRange(ctx, start, end)
cancel()
if err == nil {
return
}
if !errors.Is(err, errDashboardAggregationRunning) {
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
return
}
time.Sleep(5 * time.Second)
}
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
}()
return nil
}
func (s *DashboardAggregationService) recomputeRecentDays() { func (s *DashboardAggregationService) recomputeRecentDays() {
days := s.cfg.RecomputeDays days := s.cfg.RecomputeDays
if days <= 0 { if days <= 0 {
...@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() { ...@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
} }
} }
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
jobStart := time.Now().UTC()
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
return err
}
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
start.UTC().Format(time.RFC3339),
end.UTC().Format(time.RFC3339),
time.Since(jobStart).String(),
)
return nil
}
func (s *DashboardAggregationService) runScheduledAggregation() { func (s *DashboardAggregationService) runScheduledAggregation() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return return
...@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() { ...@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error { func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errors.New("聚合作业正在运行") return errDashboardAggregationRunning
} }
defer atomic.StoreInt32(&s.running, 0) defer atomic.StoreInt32(&s.running, 0)
......
...@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s ...@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
return s.aggregateErr return s.aggregateErr
} }
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return s.AggregateRange(ctx, start, end)
}
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return s.watermark, nil return s.watermark, nil
} }
......
...@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D ...@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil return stats, nil
} }
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream) trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil { if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err) return nil, fmt.Errorf("get usage trend with filters: %w", err)
} }
return trend, nil return trend, nil
} }
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil { if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err) return nil, fmt.Errorf("get model stats with filters: %w", err)
} }
......
...@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start ...@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return nil return nil
} }
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil { if s.err != nil {
return time.Time{}, s.err return time.Time{}, s.err
......
package service package service
import "github.com/Wei-Shaw/sub2api/internal/domain"
// Status constants // Status constants
const ( const (
StatusActive = "active" StatusActive = domain.StatusActive
StatusDisabled = "disabled" StatusDisabled = domain.StatusDisabled
StatusError = "error" StatusError = domain.StatusError
StatusUnused = "unused" StatusUnused = domain.StatusUnused
StatusUsed = "used" StatusUsed = domain.StatusUsed
StatusExpired = "expired" StatusExpired = domain.StatusExpired
) )
// Role constants // Role constants
const ( const (
RoleAdmin = "admin" RoleAdmin = domain.RoleAdmin
RoleUser = "user" RoleUser = domain.RoleUser
) )
// Platform constants // Platform constants
const ( const (
PlatformAnthropic = "anthropic" PlatformAnthropic = domain.PlatformAnthropic
PlatformOpenAI = "openai" PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = "gemini" PlatformGemini = domain.PlatformGemini
PlatformAntigravity = "antigravity" PlatformAntigravity = domain.PlatformAntigravity
) )
// Account type constants // Account type constants
const ( const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = "apikey" // API Key类型账号 AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
) )
// Redeem type constants // Redeem type constants
const ( const (
RedeemTypeBalance = "balance" RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = "concurrency" RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = "subscription" RedeemTypeSubscription = domain.RedeemTypeSubscription
) )
// PromoCode status constants // PromoCode status constants
const ( const (
PromoCodeStatusActive = "active" PromoCodeStatusActive = domain.PromoCodeStatusActive
PromoCodeStatusDisabled = "disabled" PromoCodeStatusDisabled = domain.PromoCodeStatusDisabled
) )
// Admin adjustment type constants // Admin adjustment type constants
const ( const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额 AdjustmentTypeAdminBalance = domain.AdjustmentTypeAdminBalance // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数 AdjustmentTypeAdminConcurrency = domain.AdjustmentTypeAdminConcurrency // 管理员调整并发数
) )
// Group subscription type constants // Group subscription type constants
const ( const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费) SubscriptionTypeStandard = domain.SubscriptionTypeStandard // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制) SubscriptionTypeSubscription = domain.SubscriptionTypeSubscription // 订阅模式(按限额控制)
) )
// Subscription status constants // Subscription status constants
const ( const (
SubscriptionStatusActive = "active" SubscriptionStatusActive = domain.SubscriptionStatusActive
SubscriptionStatusExpired = "expired" SubscriptionStatusExpired = domain.SubscriptionStatusExpired
SubscriptionStatusSuspended = "suspended" SubscriptionStatusSuspended = domain.SubscriptionStatusSuspended
) )
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。 // LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
...@@ -69,8 +71,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" ...@@ -69,8 +71,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys // Setting keys
const ( const (
// 注册设置 // 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
// 邮件服务设置 // 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
...@@ -86,6 +90,9 @@ const ( ...@@ -86,6 +90,9 @@ const (
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置 // LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled" SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id" SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
...@@ -93,13 +100,16 @@ const ( ...@@ -93,13 +100,16 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置 // OEM设置
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入) SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
......
...@@ -8,11 +8,18 @@ import ( ...@@ -8,11 +8,18 @@ import (
"time" "time"
) )
// Task type constants
const (
TaskTypeVerifyCode = "verify_code"
TaskTypePasswordReset = "password_reset"
)
// EmailTask 邮件发送任务 // EmailTask 邮件发送任务
type EmailTask struct { type EmailTask struct {
Email string Email string
SiteName string SiteName string
TaskType string // "verify_code" TaskType string // "verify_code" or "password_reset"
ResetURL string // Only used for password_reset task type
} }
// EmailQueueService 异步邮件队列服务 // EmailQueueService 异步邮件队列服务
...@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { ...@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer cancel() defer cancel()
switch task.TaskType { switch task.TaskType {
case "verify_code": case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil { if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else { } else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
} }
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
}
default: default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
} }
...@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { ...@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task := EmailTask{ task := EmailTask{
Email: email, Email: email,
SiteName: siteName, SiteName: siteName,
TaskType: "verify_code", TaskType: TaskTypeVerifyCode,
} }
select { select {
...@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { ...@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
} }
} }
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: TaskTypePasswordReset,
ResetURL: resetURL,
}
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
}
}
// Stop 停止队列服务 // Stop 停止队列服务
func (s *EmailQueueService) Stop() { func (s *EmailQueueService) Stop() {
close(s.stopChan) close(s.stopChan)
......
...@@ -3,11 +3,14 @@ package service ...@@ -3,11 +3,14 @@ package service
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/subtle"
"crypto/tls" "crypto/tls"
"encoding/hex"
"fmt" "fmt"
"log" "log"
"math/big" "math/big"
"net/smtp" "net/smtp"
"net/url"
"strconv" "strconv"
"time" "time"
...@@ -19,6 +22,9 @@ var ( ...@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code") ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code") ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code") ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
// Password reset errors
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
) )
// EmailCache defines cache operations for email service // EmailCache defines cache operations for email service
...@@ -26,6 +32,16 @@ type EmailCache interface { ...@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error DeleteVerificationCode(ctx context.Context, email string) error
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
DeletePasswordResetToken(ctx context.Context, email string) error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
} }
// VerificationCodeData represents verification code data // VerificationCodeData represents verification code data
...@@ -35,10 +51,22 @@ type VerificationCodeData struct { ...@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt time.Time CreatedAt time.Time
} }
// PasswordResetTokenData represents password reset token data
type PasswordResetTokenData struct {
Token string
CreatedAt time.Time
}
const ( const (
verifyCodeTTL = 15 * time.Minute verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5 maxVerifyCodeAttempts = 5
// Password reset token settings
passwordResetTokenTTL = 30 * time.Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown = 30 * time.Second
) )
// SMTPConfig SMTP配置 // SMTPConfig SMTP配置
...@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error ...@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return ErrVerifyCodeMaxAttempts return ErrVerifyCodeMaxAttempts
} }
// 验证码不匹配 // 验证码不匹配 (constant-time comparison to prevent timing attacks)
if data.Code != code { if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++ data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err) log.Printf("[Email] Failed to update verification attempt count: %v", err)
...@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error { ...@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return client.Quit() return client.Quit()
} }
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
var token string
var needSaveToken bool
// Check if token already exists
existing, err := s.cache.GetPasswordResetToken(ctx, email)
if err == nil && existing != nil {
// Token exists, reuse it (allows resending email without generating new token)
token = existing.Token
needSaveToken = false
} else {
// Generate new token
token, err = s.GeneratePasswordResetToken()
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
needSaveToken = true
}
// Save token to Redis (only if new token generated)
if needSaveToken {
data := &PasswordResetTokenData{
Token: token,
CreatedAt: time.Now(),
}
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
return fmt.Errorf("save reset token: %w", err)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
// Build email content
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
// Send email
if err := s.SendEmail(ctx, email, subject, body); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
return err
}
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
}
return nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
data, err := s.cache.GetPasswordResetToken(ctx, email)
if err != nil || data == nil {
return ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
return ErrInvalidResetToken
}
return nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
// Verify first
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
return err
}
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
}
return nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`, siteName, resetURL, resetURL)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment