Commit 8363663e authored by shaw's avatar shaw
Browse files

fix(gateway): 修复 usage_logs 记录 IP 不正确的问题

在 nginx 反向代理场景下,使用 ip.GetClientIP() 替代 c.ClientIP()
以正确获取客户端真实 IP 地址
parent b588ea19
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -291,10 +292,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -291,10 +292,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := c.ClientIP() clientIP := ip.GetClientIP(c)
// 异步记录使用量(subscription已在函数开头获取) // 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
...@@ -304,7 +305,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -304,7 +305,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: ip, IPAddress: clientIP,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
...@@ -425,10 +426,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -425,10 +426,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := c.ClientIP() clientIP := ip.GetClientIP(c)
// 异步记录使用量(subscription已在函数开头获取) // 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
...@@ -438,7 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -438,7 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: ip, IPAddress: clientIP,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment