Commit 584cfc3d authored by yangjianbo's avatar yangjianbo
Browse files

chore(logging): 完成后端日志审计与结构化迁移

- 将高密度服务与处理器日志迁移到新日志系统(LegacyPrintf/结构化日志)
- 增加 stdlog bridge 与兼容测试,保留旧日志捕获能力
- 将 OpenAI 断流告警改为结构化 Warn 并改造对应测试为 sink 捕获
- 补齐后端相关文件 logger 引用并通过全量 go test
parent eaa7d899
......@@ -3,9 +3,10 @@ package service
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// Task type constants
......@@ -56,7 +57,7 @@ func (s *EmailQueueService) start() {
s.wg.Add(1)
go s.worker(i)
}
log.Printf("[EmailQueue] Started %d workers", s.workers)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Started %d workers", s.workers)
}
// worker 工作协程
......@@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) {
case task := <-s.taskChan:
s.processTask(id, task)
case <-s.stopChan:
log.Printf("[EmailQueue] Worker %d stopping", id)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d stopping", id)
return
}
}
......@@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
switch task.TaskType {
case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
}
......@@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued verify code task for %s", email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued verify code task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
......@@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued password reset task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
......@@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
func (s *EmailQueueService) Stop() {
close(s.stopChan)
s.wg.Wait()
log.Println("[EmailQueue] All workers stopped")
logger.LegacyPrintf("service.email_queue", "%s", "[EmailQueue] All workers stopped")
}
......@@ -2,13 +2,13 @@ package service
import (
"context"
"log"
"sort"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
......@@ -62,9 +62,9 @@ func NewErrorPassthroughService(
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.reloadRulesFromDB(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
}
}
......@@ -72,7 +72,7 @@ func NewErrorPassthroughService(
if cache != nil {
cache.SubscribeUpdates(ctx, func() {
if err := svc.refreshLocalCache(context.Background()); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
}
})
}
......@@ -180,7 +180,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule
// 如果本地缓存为空,尝试刷新
ctx := context.Background()
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache: %v", err)
return nil
}
......@@ -213,7 +213,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
// 更新 Redis 缓存
if s.cache != nil {
if err := s.cache.Set(ctx, rules); err != nil {
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to set cache: %v", err)
}
}
......@@ -254,13 +254,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 先失效缓存,避免后续刷新读到陈旧规则。
if s.cache != nil {
if err := s.cache.Invalidate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to invalidate cache: %v", err)
}
}
// 刷新本地缓存
if err := s.reloadRulesFromDB(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh local cache: %v", err)
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
s.clearLocalCache()
}
......@@ -268,7 +268,7 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 通知其他实例
if s.cache != nil {
if err := s.cache.NotifyUpdate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to notify cache update: %v", err)
}
}
}
......
This diff is collapsed.
......@@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"io"
"log"
"math"
mathrand "math/rand"
"net/http"
......@@ -22,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
......@@ -282,7 +282,7 @@ func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Contex
}
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
return ok
}
......@@ -698,7 +698,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: safeErr,
})
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
......@@ -754,7 +754,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
if txErr == nil {
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
geminiReq = retryGeminiReq
// Consume one retry budget attempt and continue with the updated request payload.
sleepGeminiBackoff(1)
......@@ -821,7 +821,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Detail: upstreamDetail,
})
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
......@@ -1166,7 +1166,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: safeErr,
})
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
......@@ -1235,7 +1235,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Detail: upstreamDetail,
})
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
......@@ -1367,7 +1367,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
......@@ -1544,7 +1544,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
})
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
......@@ -2299,13 +2299,13 @@ type UpstreamHTTPResult struct {
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Response Headers ==========")
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ========================================")
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
respBody, err := io.ReadAll(resp.Body)
if err != nil {
......@@ -2339,13 +2339,13 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ====================================================")
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
......@@ -2640,16 +2640,16 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// 兜底:5 分钟
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
......@@ -2659,7 +2659,7 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
// 使用解析到的重置时间
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
}
......
......@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strconv"
......@@ -16,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
......@@ -328,27 +328,27 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
// inferGoogleOneTier infers Google One tier from Drive storage limit
func inferGoogleOneTier(storageBytes int64) string {
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
if storageBytes <= 0 {
log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
return GeminiTierGoogleOneUnknown
}
if storageBytes > StorageTierUnlimited {
log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
return GeminiTierGoogleAIUltra
}
if storageBytes >= StorageTierAIPremium {
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
return GeminiTierGoogleAIPro
}
if storageBytes >= StorageTierFree {
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
return GeminiTierGoogleOneFree
}
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
return GeminiTierGoogleOneUnknown
}
......@@ -358,30 +358,30 @@ func inferGoogleOneTier(storageBytes int64) string {
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...")
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil {
// Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") {
log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API scope not available (403): %v", err)
return GeminiTierGoogleOneUnknown, nil, err
}
// Other errors
log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Failed to fetch Drive storage: %v", err)
return GeminiTierGoogleOneUnknown, nil, err
}
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
tierID := inferGoogleOneTier(storageInfo.Limit)
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Inferred tier from storage: %s", tierID)
return tierID, storageInfo, nil
}
......@@ -441,16 +441,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode START ==========")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] SessionID: %s", input.SessionID)
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Session not found or expired")
return nil, fmt.Errorf("session not found or expired")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
log.Printf("[GeminiOAuth] ERROR: Invalid state")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Invalid state")
return nil, fmt.Errorf("invalid state")
}
......@@ -461,7 +461,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
proxyURL = proxy.URL()
}
}
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", proxyURL)
redirectURI := session.RedirectURI
......@@ -470,8 +470,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
if oauthType == "" {
oauthType = "code_assist"
}
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Project ID from session: %s", session.ProjectID)
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
if oauthType == "ai_studio" {
......@@ -496,12 +496,12 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
log.Printf("[GeminiOAuth] Token exchange successful")
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token exchange successful")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token scope: %s", tokenResp.Scope)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
......@@ -523,40 +523,40 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
}
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection START ==========")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType)
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
switch oauthType {
case "code_assist":
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing code_assist OAuth type")
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
} else {
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
}
} else {
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
} else {
tierID = fetchedTierID
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
}
}
if strings.TrimSpace(projectID) == "" {
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
// Prefer auto-detected tier; fall back to user-selected tier.
......@@ -564,31 +564,31 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
if tierID == "" {
if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGCPStandard
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID)
}
}
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing google_one OAuth type")
// Google One accounts use cloudaicompanion API, which requires a project_id.
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
}
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s", projectID)
}
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier
var storageInfo *geminicli.DriveStorageInfo
var err error
......@@ -596,12 +596,12 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
if err != nil {
// Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
tierID = ""
} else {
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
if storageInfo != nil {
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
}
......@@ -610,10 +610,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGoogleOneFree
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID)
}
}
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
......@@ -636,7 +636,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
},
}
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
return tokenInfo, nil
}
......@@ -649,10 +649,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
}
default:
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
}
log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection END ==========")
result := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
......@@ -665,8 +665,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
TierID: tierID,
OAuthType: oauthType,
}
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END ==========")
return result, nil
}
......@@ -949,23 +949,23 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
registeredTierID := strings.TrimSpace(loadResp.GetTier())
if registeredTierID != "" {
// 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
// Try to get project from Cloud Resource Manager
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
return strings.TrimSpace(fallback), tierID, nil
}
// No project found - user must provide project_id manually
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
}
}
// 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
req := &geminicli.OnboardUserRequest{
TierID: tierID,
......
......@@ -7,13 +7,14 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"log"
"log/slog"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// 预编译正则表达式(避免每次调用重新编译)
......@@ -84,7 +85,7 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
cached.UserAgent = clientUA
// 保存更新后的指纹
_ = s.cache.SetFingerprint(ctx, accountID, cached)
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return cached, nil
}
......@@ -97,10 +98,10 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 保存到缓存(永不过期)
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
logger.LegacyPrintf("service.identity", "Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
return fp, nil
}
......@@ -277,19 +278,19 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
// 获取或生成固定的伪装 session ID
maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
if err != nil {
log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
logger.LegacyPrintf("service.identity", "Warning: failed to get masked session ID for account %d: %v", account.ID, err)
return newBody, nil
}
if maskedSessionID == "" {
// 首次或已过期,生成新的伪装 session ID
maskedSessionID = generateRandomUUID()
log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
logger.LegacyPrintf("service.identity", "Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
}
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
......@@ -335,7 +336,7 @@ func generateClientID() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
// 极罕见的情况,使用时间戳+固定值作为fallback
log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err)
logger.LegacyPrintf("service.identity", "Warning: crypto/rand.Read failed: %v, using fallback", err)
// 使用SHA256(当前纳秒时间)作为fallback
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
return hex.EncodeToString(h[:])
......
......@@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"sort"
"strconv"
......@@ -19,12 +18,14 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
const (
......@@ -786,7 +787,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
reqBody["model"] = mappedModel
bodyModified = true
}
......@@ -795,7 +796,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if model, ok := reqBody["model"].(string); ok {
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" && normalizedModel != model {
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, normalizedModel, account.Name, account.Type, isCodexCLI)
reqBody["model"] = normalizedModel
mappedModel = normalizedModel
......@@ -808,7 +809,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
reasoning["effort"] = "none"
bodyModified = true
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
}
}
......@@ -1012,7 +1013,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
reqStream bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
log.Printf(
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
account.Name,
......@@ -1022,18 +1023,15 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
)
if reqStream && c != nil && c.Request != nil {
if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 {
streamWarnLogger := logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.Strings("timeout_headers", timeoutHeaders),
)
if s.isOpenAIPassthroughTimeoutHeadersAllowed() {
log.Printf(
"[WARN] [OpenAI passthrough] 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流: account=%d headers=%s",
account.ID,
strings.Join(timeoutHeaders, ", "),
)
streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流")
} else {
log.Printf(
"[WARN] [OpenAI passthrough] 检测到超时相关请求头,将按配置过滤以降低断流风险: account=%d headers=%s",
account.ID,
strings.Join(timeoutHeaders, ", "),
)
streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险")
}
}
}
......@@ -1347,7 +1345,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if !clientDisconnected {
if _, err := fmt.Fprintln(w, line); err != nil {
clientDisconnected = true
log.Printf("[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else {
flusher.Flush()
}
......@@ -1355,11 +1353,11 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
if err := scanner.Err(); err != nil {
if clientDisconnected {
log.Printf("[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf(
logger.LegacyPrintf("service.openai_gateway",
"[WARN] [OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
account.ID,
upstreamRequestID,
......@@ -1369,10 +1367,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(err, bufio.ErrTooLong) {
log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
}
log.Printf(
logger.LegacyPrintf("service.openai_gateway",
"[WARN] [OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
account.ID,
upstreamRequestID,
......@@ -1381,11 +1379,11 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
if !clientDisconnected && !sawDone && ctx.Err() == nil {
log.Printf(
"[WARN] [OpenAI passthrough] 上游流在未收到 [DONE] 时结束,疑似断流: account=%d request_id=%s",
account.ID,
upstreamRequestID,
)
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Warn("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
......@@ -1584,7 +1582,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
logger.LegacyPrintf("service.openai_gateway",
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
......@@ -1844,16 +1842,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
......@@ -1882,7 +1880,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
......@@ -1899,7 +1897,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
......@@ -1912,10 +1910,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
if clientDisconnected {
log.Printf("Upstream timeout after client disconnect, returning collected usage")
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
......@@ -1932,7 +1930,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
continue
}
flusher.Flush()
......@@ -2323,7 +2321,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
......@@ -2346,7 +2344,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Update API key quota if applicable (only for balance mode with quota set)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
log.Printf("Update API key quota failed: %v", err)
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
}
}
......
......@@ -3,17 +3,17 @@ package service
import (
"bytes"
"context"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
......@@ -46,24 +46,76 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
var stdLogCaptureMu sync.Mutex
var structuredLogCaptureMu sync.Mutex
func captureStdLog(t *testing.T) (*bytes.Buffer, func()) {
t.Helper()
stdLogCaptureMu.Lock()
buf := &bytes.Buffer{}
prevWriter := log.Writer()
prevFlags := log.Flags()
log.SetFlags(0)
log.SetOutput(buf)
return buf, func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
// 防御性恢复,避免其他测试改动了底层 writer。
if prevWriter == nil {
log.SetOutput(os.Stderr)
type inMemoryLogSink struct {
mu sync.Mutex
events []*logger.LogEvent
}
func (s *inMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
if event == nil {
return
}
cloned := *event
if event.Fields != nil {
cloned.Fields = make(map[string]any, len(event.Fields))
for k, v := range event.Fields {
cloned.Fields[k] = v
}
}
s.mu.Lock()
s.events = append(s.events, &cloned)
s.mu.Unlock()
}
func (s *inMemoryLogSink) ContainsMessage(substr string) bool {
s.mu.Lock()
defer s.mu.Unlock()
for _, ev := range s.events {
if ev != nil && strings.Contains(ev.Message, substr) {
return true
}
}
return false
}
func (s *inMemoryLogSink) ContainsFieldValue(field, substr string) bool {
s.mu.Lock()
defer s.mu.Unlock()
for _, ev := range s.events {
if ev == nil || ev.Fields == nil {
continue
}
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
return true
}
stdLogCaptureMu.Unlock()
}
return false
}
func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) {
t.Helper()
structuredLogCaptureMu.Lock()
err := logger.Init(logger.InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: logger.SamplingOptions{Enabled: false},
})
require.NoError(t, err)
sink := &inMemoryLogSink{}
logger.SetSink(sink)
return sink, func() {
logger.SetSink(nil)
structuredLogCaptureMu.Unlock()
}
}
......@@ -486,7 +538,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) {
gin.SetMode(gin.TestMode)
logBuf, restore := captureStdLog(t)
logSink, restore := captureStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
......@@ -521,13 +573,13 @@ func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *
_, err := svc.Forward(context.Background(), c, account, originalBody)
require.NoError(t, err)
require.Contains(t, logBuf.String(), "检测到超时相关请求头,将按配置过滤以降低断流风险")
require.Contains(t, logBuf.String(), "x-stainless-timeout=10000")
require.True(t, logSink.ContainsMessage("检测到超时相关请求头,将按配置过滤以降低断流风险"))
require.True(t, logSink.ContainsFieldValue("timeout_headers", "x-stainless-timeout=10000"))
}
func TestOpenAIGatewayService_OAuthPassthrough_WarnWhenStreamEndsWithoutDone(t *testing.T) {
gin.SetMode(gin.TestMode)
logBuf, restore := captureStdLog(t)
logSink, restore := captureStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
......@@ -562,8 +614,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_WarnWhenStreamEndsWithoutDone(t *
_, err := svc.Forward(context.Background(), c, account, originalBody)
require.NoError(t, err)
require.Contains(t, logBuf.String(), "上游流在未收到 [DONE] 时结束,疑似断流")
require.Contains(t, logBuf.String(), "rid-truncate")
require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流"))
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate"))
}
func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *testing.T) {
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment