Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
d04b47b3
Commit
d04b47b3
authored
Feb 14, 2026
by
yangjianbo
Browse files
feat(backend): 提交后端审计修复与配套测试改动
parent
86219914
Changes
22
Hide whitespace changes
Inline
Side-by-side
README_CN.md
View file @
d04b47b3
...
...
@@ -404,6 +404,14 @@ gateway:
-
`server.trusted_proxies`
启用可信代理解析 X-Forwarded-For
-
`turnstile.required`
在 release 模式强制启用 Turnstile
**网关防御纵深建议(重点)**
-
`gateway.upstream_response_read_max_bytes`
:限制非流式上游响应读取大小(默认
`8MB`
),用于防止异常响应导致内存放大。
-
`gateway.proxy_probe_response_read_max_bytes`
:限制代理探测响应读取大小(默认
`1MB`
)。
-
`gateway.gemini_debug_response_headers`
:默认
`false`
,仅在排障时短时开启,避免高频请求日志开销。
-
`/auth/register`
、
`/auth/login`
、
`/auth/login/2fa`
、
`/auth/send-verify-code`
已提供服务端兜底限流(Redis 故障时 fail-close)。
-
推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。
**⚠️ 安全警告:HTTP URL 配置**
当
`security.url_allowlist.enabled=false`
时,系统默认执行最小 URL 校验,
**拒绝 HTTP URL**
,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置:
...
...
backend/internal/config/config.go
View file @
d04b47b3
...
...
@@ -308,6 +308,12 @@ type GatewayConfig struct {
ResponseHeaderTimeout
int
`mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize
int64
`mapstructure:"max_body_size"`
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
UpstreamResponseReadMaxBytes
int64
`mapstructure:"upstream_response_read_max_bytes"`
// 代理探测响应体读取上限(字节)
ProxyProbeResponseReadMaxBytes
int64
`mapstructure:"proxy_probe_response_read_max_bytes"`
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
GeminiDebugResponseHeaders
bool
`mapstructure:"gemini_debug_response_headers"`
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
ConnectionPoolIsolation
string
`mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
...
...
@@ -1059,6 +1065,9 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.openai_passthrough_allow_timeout_headers"
,
false
)
viper
.
SetDefault
(
"gateway.antigravity_fallback_cooldown_minutes"
,
1
)
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.upstream_response_read_max_bytes"
,
int64
(
8
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.proxy_probe_response_read_max_bytes"
,
int64
(
1024
*
1024
))
viper
.
SetDefault
(
"gateway.gemini_debug_response_headers"
,
false
)
viper
.
SetDefault
(
"gateway.sora_max_body_size"
,
int64
(
256
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.sora_stream_timeout_seconds"
,
900
)
viper
.
SetDefault
(
"gateway.sora_request_timeout_seconds"
,
180
)
...
...
@@ -1465,6 +1474,12 @@ func (c *Config) Validate() error {
if
c
.
Gateway
.
MaxBodySize
<=
0
{
return
fmt
.
Errorf
(
"gateway.max_body_size must be positive"
)
}
if
c
.
Gateway
.
UpstreamResponseReadMaxBytes
<=
0
{
return
fmt
.
Errorf
(
"gateway.upstream_response_read_max_bytes must be positive"
)
}
if
c
.
Gateway
.
ProxyProbeResponseReadMaxBytes
<=
0
{
return
fmt
.
Errorf
(
"gateway.proxy_probe_response_read_max_bytes must be positive"
)
}
if
c
.
Gateway
.
SoraMaxBodySize
<
0
{
return
fmt
.
Errorf
(
"gateway.sora_max_body_size must be non-negative"
)
}
...
...
backend/internal/handler/gateway_handler.go
View file @
d04b47b3
...
...
@@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
continue
}
// 错误响应已在Forward中处理,这里只记录日志
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
wroteFallback
:=
h
.
ensureForwardErrorResponse
(
c
,
streamStarted
)
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Bool
(
"fallback_error_response_written"
,
wroteFallback
),
zap
.
Error
(
err
),
)
return
}
...
...
@@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
continue
}
// 错误响应已在Forward中处理,这里只记录日志
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
wroteFallback
:=
h
.
ensureForwardErrorResponse
(
c
,
streamStarted
)
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Bool
(
"fallback_error_response_written"
,
wroteFallback
),
zap
.
Error
(
err
),
)
return
}
...
...
@@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
h
.
errorResponse
(
c
,
status
,
errType
,
message
)
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func
(
h
*
GatewayHandler
)
ensureForwardErrorResponse
(
c
*
gin
.
Context
,
streamStarted
bool
)
bool
{
if
c
==
nil
||
c
.
Writer
==
nil
||
c
.
Writer
.
Written
()
{
return
false
}
h
.
handleStreamingAwareError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
,
streamStarted
)
return
true
}
// errorResponse 返回Claude API格式的错误响应
func
(
h
*
GatewayHandler
)
errorResponse
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
)
{
c
.
JSON
(
status
,
gin
.
H
{
...
...
backend/internal/handler/gateway_handler_error_fallback_test.go
0 → 100644
View file @
d04b47b3
package
handler
import
(
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
GatewayHandler
{}
wrote
:=
h
.
ensureForwardErrorResponse
(
c
,
false
)
require
.
True
(
t
,
wrote
)
require
.
Equal
(
t
,
http
.
StatusBadGateway
,
w
.
Code
)
var
parsed
map
[
string
]
any
err
:=
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
parsed
)
require
.
NoError
(
t
,
err
)
assert
.
Equal
(
t
,
"error"
,
parsed
[
"type"
])
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
assert
.
Equal
(
t
,
"Upstream request failed"
,
errorObj
[
"message"
])
}
func
TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
String
(
http
.
StatusTeapot
,
"already written"
)
h
:=
&
GatewayHandler
{}
wrote
:=
h
.
ensureForwardErrorResponse
(
c
,
false
)
require
.
False
(
t
,
wrote
)
require
.
Equal
(
t
,
http
.
StatusTeapot
,
w
.
Code
)
assert
.
Equal
(
t
,
"already written"
,
w
.
Body
.
String
())
}
backend/internal/handler/openai_gateway_handler.go
View file @
d04b47b3
...
...
@@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
continue
}
// Error response already handled in Forward, just log
reqLog
.
Error
(
"openai.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
wroteFallback
:=
h
.
ensureForwardErrorResponse
(
c
,
streamStarted
)
reqLog
.
Error
(
"openai.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Bool
(
"fallback_error_response_written"
,
wroteFallback
),
zap
.
Error
(
err
),
)
return
}
...
...
@@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
h
.
errorResponse
(
c
,
status
,
errType
,
message
)
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func
(
h
*
OpenAIGatewayHandler
)
ensureForwardErrorResponse
(
c
*
gin
.
Context
,
streamStarted
bool
)
bool
{
if
c
==
nil
||
c
.
Writer
==
nil
||
c
.
Writer
.
Written
()
{
return
false
}
h
.
handleStreamingAwareError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
,
streamStarted
)
return
true
}
// errorResponse returns OpenAI API format error response
func
(
h
*
OpenAIGatewayHandler
)
errorResponse
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
)
{
c
.
JSON
(
status
,
gin
.
H
{
...
...
backend/internal/handler/openai_gateway_handler_test.go
View file @
d04b47b3
...
...
@@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
assert
.
Equal
(
t
,
"test error"
,
errorObj
[
"message"
])
}
func
TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
OpenAIGatewayHandler
{}
wrote
:=
h
.
ensureForwardErrorResponse
(
c
,
false
)
require
.
True
(
t
,
wrote
)
require
.
Equal
(
t
,
http
.
StatusBadGateway
,
w
.
Code
)
var
parsed
map
[
string
]
any
err
:=
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
parsed
)
require
.
NoError
(
t
,
err
)
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
assert
.
Equal
(
t
,
"Upstream request failed"
,
errorObj
[
"message"
])
}
func
TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
String
(
http
.
StatusTeapot
,
"already written"
)
h
:=
&
OpenAIGatewayHandler
{}
wrote
:=
h
.
ensureForwardErrorResponse
(
c
,
false
)
require
.
False
(
t
,
wrote
)
require
.
Equal
(
t
,
http
.
StatusTeapot
,
w
.
Code
)
assert
.
Equal
(
t
,
"already written"
,
w
.
Body
.
String
())
}
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
func
TestOpenAIHandler_GjsonExtraction
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
...
...
backend/internal/pkg/ip/ip.go
View file @
d04b47b3
...
...
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
return
normalizeIP
(
c
.
ClientIP
())
}
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
// 适用于 ACL / 风控等安全敏感场景。
func
GetTrustedClientIP
(
c
*
gin
.
Context
)
string
{
if
c
==
nil
{
return
""
}
return
normalizeIP
(
c
.
ClientIP
())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。
func
normalizeIP
(
ip
string
)
string
{
ip
=
strings
.
TrimSpace
(
ip
)
...
...
backend/internal/pkg/ip/ip_test.go
View file @
d04b47b3
...
...
@@ -3,8 +3,10 @@
package
ip
import
(
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
...
...
@@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) {
})
}
}
func
TestGetTrustedClientIPUsesGinClientIP
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
require
.
NoError
(
t
,
r
.
SetTrustedProxies
(
nil
))
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
c
.
String
(
200
,
GetTrustedClientIP
(
c
))
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
"GET"
,
"/t"
,
nil
)
req
.
RemoteAddr
=
"9.9.9.9:12345"
req
.
Header
.
Set
(
"X-Forwarded-For"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"X-Real-IP"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"CF-Connecting-IP"
,
"1.2.3.4"
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
200
,
w
.
Code
)
require
.
Equal
(
t
,
"9.9.9.9"
,
w
.
Body
.
String
())
}
backend/internal/repository/proxy_probe_service.go
View file @
d04b47b3
...
...
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure
:=
false
allowPrivate
:=
false
validateResolvedIP
:=
true
maxResponseBytes
:=
defaultProxyProbeResponseMaxBytes
if
cfg
!=
nil
{
insecure
=
cfg
.
Security
.
ProxyProbe
.
InsecureSkipVerify
allowPrivate
=
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
validateResolvedIP
=
cfg
.
Security
.
URLAllowlist
.
Enabled
if
cfg
.
Gateway
.
ProxyProbeResponseReadMaxBytes
>
0
{
maxResponseBytes
=
cfg
.
Gateway
.
ProxyProbeResponseReadMaxBytes
}
}
if
insecure
{
log
.
Printf
(
"[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure."
)
...
...
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecureSkipVerify
:
insecure
,
allowPrivateHosts
:
allowPrivate
,
validateResolvedIP
:
validateResolvedIP
,
maxResponseBytes
:
maxResponseBytes
,
}
}
const
(
defaultProxyProbeTimeout
=
30
*
time
.
Second
defaultProxyProbeTimeout
=
30
*
time
.
Second
defaultProxyProbeResponseMaxBytes
=
int64
(
1024
*
1024
)
)
// probeURLs 按优先级排列的探测 URL 列表
...
...
@@ -52,6 +58,7 @@ type proxyProbeService struct {
insecureSkipVerify
bool
allowPrivateHosts
bool
validateResolvedIP
bool
maxResponseBytes
int64
}
func
(
s
*
proxyProbeService
)
ProbeProxy
(
ctx
context
.
Context
,
proxyURL
string
)
(
*
service
.
ProxyExitInfo
,
int64
,
error
)
{
...
...
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"request failed with status: %d"
,
resp
.
StatusCode
)
}
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
maxResponseBytes
:=
s
.
maxResponseBytes
if
maxResponseBytes
<=
0
{
maxResponseBytes
=
defaultProxyProbeResponseMaxBytes
}
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
maxResponseBytes
+
1
))
if
err
!=
nil
{
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"failed to read response: %w"
,
err
)
}
if
int64
(
len
(
body
))
>
maxResponseBytes
{
return
nil
,
latencyMs
,
fmt
.
Errorf
(
"proxy probe response exceeds limit: %d"
,
maxResponseBytes
)
}
switch
parser
{
case
"ip-api"
:
...
...
backend/internal/server/http.go
View file @
d04b47b3
...
...
@@ -51,6 +51,9 @@ func ProvideRouter(
if
err
:=
r
.
SetTrustedProxies
(
nil
);
err
!=
nil
{
log
.
Printf
(
"Failed to disable trusted proxies: %v"
,
err
)
}
if
cfg
.
Server
.
Mode
==
"release"
{
log
.
Printf
(
"Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled"
)
}
}
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
,
redisClient
)
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
d04b47b3
...
...
@@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if
len
(
apiKey
.
IPWhitelist
)
>
0
||
len
(
apiKey
.
IPBlacklist
)
>
0
{
clientIP
:=
ip
.
GetClientIP
(
c
)
clientIP
:=
ip
.
Get
Trusted
ClientIP
(
c
)
allowed
,
_
:=
ip
.
CheckIPRestriction
(
clientIP
,
apiKey
.
IPWhitelist
,
apiKey
.
IPBlacklist
)
if
!
allowed
{
AbortWithError
(
c
,
403
,
"ACCESS_DENIED"
,
"Access denied"
)
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
d04b47b3
...
...
@@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
IPWhitelist
:
[]
string
{
"1.2.3.4"
},
}
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
require
.
NoError
(
t
,
router
.
SetTrustedProxies
(
nil
))
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
RemoteAddr
=
"9.9.9.9:12345"
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
.
Header
.
Set
(
"X-Forwarded-For"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"X-Real-IP"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"CF-Connecting-IP"
,
"1.2.3.4"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"ACCESS_DENIED"
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
...
backend/internal/server/routes/auth.go
View file @
d04b47b3
...
...
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
// 公开接口
auth
:=
v1
.
Group
(
"/auth"
)
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
auth
.
POST
(
"/register"
,
rateLimiter
.
LimitWithOptions
(
"auth-register"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
rateLimiter
.
LimitWithOptions
(
"auth-login"
,
20
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
rateLimiter
.
LimitWithOptions
(
"auth-login-2fa"
,
20
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
rateLimiter
.
LimitWithOptions
(
"auth-send-verify-code"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
SendVerifyCode
)
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
auth
.
POST
(
"/refresh"
,
rateLimiter
.
LimitWithOptions
(
"refresh-token"
,
30
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
...
...
backend/internal/server/routes/auth_rate_limit_integration_test.go
0 → 100644
View file @
d04b47b3
//go:build integration
package
routes
import
(
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis
"github.com/testcontainers/testcontainers-go/modules/redis"
)
const
authRouteRedisImageTag
=
"redis:8.4-alpine"
func
TestAuthRegisterRateLimitThresholdHitReturns429
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
rdb
:=
startAuthRouteRedis
(
t
,
ctx
)
router
:=
newAuthRoutesTestRouter
(
rdb
)
const
path
=
"/api/v1/auth/register"
for
i
:=
1
;
i
<=
6
;
i
++
{
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
path
,
strings
.
NewReader
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
RemoteAddr
=
"198.51.100.10:23456"
w
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
w
,
req
)
if
i
<=
5
{
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
,
"第 %d 次请求应先进入业务校验"
,
i
)
continue
}
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
w
.
Code
,
"第 6 次请求应命中限流"
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"rate limit exceeded"
)
}
}
func
startAuthRouteRedis
(
t
*
testing
.
T
,
ctx
context
.
Context
)
*
redis
.
Client
{
t
.
Helper
()
ensureAuthRouteDockerAvailable
(
t
)
redisContainer
,
err
:=
tcredis
.
Run
(
ctx
,
authRouteRedisImageTag
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
redisContainer
.
Terminate
(
ctx
)
})
redisHost
,
err
:=
redisContainer
.
Host
(
ctx
)
require
.
NoError
(
t
,
err
)
redisPort
,
err
:=
redisContainer
.
MappedPort
(
ctx
,
"6379/tcp"
)
require
.
NoError
(
t
,
err
)
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
fmt
.
Sprintf
(
"%s:%d"
,
redisHost
,
redisPort
.
Int
()),
DB
:
0
,
})
require
.
NoError
(
t
,
rdb
.
Ping
(
ctx
)
.
Err
())
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
return
rdb
}
func
ensureAuthRouteDockerAvailable
(
t
*
testing
.
T
)
{
t
.
Helper
()
if
authRouteDockerAvailable
()
{
return
}
t
.
Skip
(
"Docker 未启用,跳过认证限流集成测试"
)
}
func
authRouteDockerAvailable
()
bool
{
if
os
.
Getenv
(
"DOCKER_HOST"
)
!=
""
{
return
true
}
socketCandidates
:=
[]
string
{
"/var/run/docker.sock"
,
filepath
.
Join
(
os
.
Getenv
(
"XDG_RUNTIME_DIR"
),
"docker.sock"
),
filepath
.
Join
(
authRouteUserHomeDir
(),
".docker"
,
"run"
,
"docker.sock"
),
filepath
.
Join
(
authRouteUserHomeDir
(),
".docker"
,
"desktop"
,
"docker.sock"
),
filepath
.
Join
(
"/run/user"
,
strconv
.
Itoa
(
os
.
Getuid
()),
"docker.sock"
),
}
for
_
,
socket
:=
range
socketCandidates
{
if
socket
==
""
{
continue
}
if
_
,
err
:=
os
.
Stat
(
socket
);
err
==
nil
{
return
true
}
}
return
false
}
func
authRouteUserHomeDir
()
string
{
home
,
err
:=
os
.
UserHomeDir
()
if
err
!=
nil
{
return
""
}
return
home
}
backend/internal/server/routes/auth_rate_limit_test.go
0 → 100644
View file @
d04b47b3
package
routes
import
(
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func
newAuthRoutesTestRouter
(
redisClient
*
redis
.
Client
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
v1
:=
router
.
Group
(
"/api/v1"
)
RegisterAuthRoutes
(
v1
,
&
handler
.
Handlers
{
Auth
:
&
handler
.
AuthHandler
{},
Setting
:
&
handler
.
SettingHandler
{},
},
servermiddleware
.
JWTAuthMiddleware
(
func
(
c
*
gin
.
Context
)
{
c
.
Next
()
}),
redisClient
,
)
return
router
}
func
TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable
(
t
*
testing
.
T
)
{
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
"127.0.0.1:1"
,
DialTimeout
:
50
*
time
.
Millisecond
,
ReadTimeout
:
50
*
time
.
Millisecond
,
WriteTimeout
:
50
*
time
.
Millisecond
,
})
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
router
:=
newAuthRoutesTestRouter
(
rdb
)
paths
:=
[]
string
{
"/api/v1/auth/register"
,
"/api/v1/auth/login"
,
"/api/v1/auth/login/2fa"
,
"/api/v1/auth/send-verify-code"
,
}
for
_
,
path
:=
range
paths
{
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
path
,
strings
.
NewReader
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
RemoteAddr
=
"203.0.113.10:12345"
w
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
w
.
Code
,
"path=%s"
,
path
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"rate limit exceeded"
,
"path=%s"
,
path
)
}
}
backend/internal/service/gateway_service.go
View file @
d04b47b3
...
...
@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
if
account
.
Platform
==
PlatformGemini
&&
resp
.
StatusCode
<
400
{
if
account
.
Platform
==
PlatformGemini
&&
resp
.
StatusCode
<
400
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
GeminiDebugResponseHeaders
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[DEBUG] Gemini API Response Headers for account %d:"
,
account
.
ID
)
for
k
,
v
:=
range
resp
.
Header
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[DEBUG] %s: %v"
,
k
,
v
)
...
...
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
maxBytes
:=
resolveUpstreamResponseReadLimit
(
s
.
cfg
)
body
,
err
:=
readUpstreamResponseBodyLimited
(
resp
.
Body
,
maxBytes
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUpstreamResponseBodyTooLarge
)
{
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
"upstream response too large"
,
""
)
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream response too large"
,
},
})
}
return
nil
,
err
}
...
...
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 读取响应体
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
maxReadBytes
:=
resolveUpstreamResponseReadLimit
(
s
.
cfg
)
respBody
,
err
:=
readUpstreamResponseBodyLimited
(
resp
.
Body
,
maxReadBytes
)
_
=
resp
.
Body
.
Close
()
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUpstreamResponseBodyTooLarge
)
{
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
"upstream response too large"
,
""
)
s
.
countTokensError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream response too large"
)
return
err
}
s
.
countTokensError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Failed to read response"
)
return
err
}
...
...
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
resp
=
retryResp
respBody
,
err
=
io
.
ReadAll
(
resp
.
Body
)
respBody
,
err
=
readUpstreamResponseBodyLimited
(
resp
.
Body
,
maxReadBytes
)
_
=
resp
.
Body
.
Close
()
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUpstreamResponseBodyTooLarge
)
{
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
"upstream response too large"
,
""
)
s
.
countTokensError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream response too large"
)
return
err
}
s
.
countTokensError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Failed to read response"
)
return
err
}
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
d04b47b3
...
...
@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
}
func
(
s
*
GeminiMessagesCompatService
)
handleNativeNonStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
isOAuth
bool
)
(
*
ClaudeUsage
,
error
)
{
// Log response headers for debugging
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ========== Response Headers =========="
)
for
key
,
values
:=
range
resp
.
Header
{
if
strings
.
HasPrefix
(
strings
.
ToLower
(
key
),
"x-ratelimit"
)
{
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] %s: %v"
,
key
,
values
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
GeminiDebugResponseHeaders
{
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ========== Response Headers =========="
)
for
key
,
values
:=
range
resp
.
Header
{
if
strings
.
HasPrefix
(
strings
.
ToLower
(
key
),
"x-ratelimit"
)
{
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] %s: %v"
,
key
,
values
)
}
}
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ========================================"
)
}
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ========================================"
)
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
maxBytes
:=
resolveUpstreamResponseReadLimit
(
s
.
cfg
)
respBody
,
err
:=
readUpstreamResponseBodyLimited
(
resp
.
Body
,
maxBytes
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUpstreamResponseBodyTooLarge
)
{
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
"upstream response too large"
,
""
)
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream response too large"
,
},
})
}
return
nil
,
err
}
var
parsed
map
[
string
]
any
if
isOAuth
{
unwrappedBody
,
uwErr
:=
unwrapGeminiResponse
(
respBody
)
if
uwErr
==
nil
{
respBody
=
unwrappedBody
}
_
=
json
.
Unmarshal
(
respBody
,
&
parsed
)
}
else
{
_
=
json
.
Unmarshal
(
respBody
,
&
parsed
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
...
...
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
func
(
s
*
GeminiMessagesCompatService
)
handleNativeStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
,
isOAuth
bool
)
(
*
geminiNativeStreamResult
,
error
)
{
// Log response headers for debugging
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ========== Streaming Response Headers =========="
)
for
key
,
values
:=
range
resp
.
Header
{
if
strings
.
HasPrefix
(
strings
.
ToLower
(
key
),
"x-ratelimit"
)
{
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] %s: %v"
,
key
,
values
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
GeminiDebugResponseHeaders
{
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ========== Streaming Response Headers =========="
)
for
key
,
values
:=
range
resp
.
Header
{
if
strings
.
HasPrefix
(
strings
.
ToLower
(
key
),
"x-ratelimit"
)
{
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] %s: %v"
,
key
,
values
)
}
}
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ===================================================="
)
}
logger
.
LegacyPrintf
(
"service.gemini_messages_compat"
,
"[GeminiAPI] ===================================================="
)
if
s
.
cfg
!=
nil
{
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
...
...
backend/internal/service/gemini_messages_compat_service_test.go
View file @
d04b47b3
...
...
@@ -3,10 +3,15 @@ package service
import
(
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
...
...
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
}
}
func
TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
logSink
,
restore
:=
captureStructuredLog
(
t
)
defer
restore
()
svc
:=
&
GeminiMessagesCompatService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
GeminiDebugResponseHeaders
:
false
,
},
},
}
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
"X-RateLimit-Limit"
:
[]
string
{
"60"
},
},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`
)),
}
usage
,
err
:=
svc
.
handleNativeNonStreamingResponse
(
c
,
resp
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
usage
)
require
.
False
(
t
,
logSink
.
ContainsMessage
(
"[GeminiAPI]"
),
"debug 关闭时不应输出 Gemini 响应头日志"
)
}
func
TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse
(
t
*
testing
.
T
)
{
claudeReq
:=
map
[
string
]
any
{
"model"
:
"claude-haiku-4-5-20251001"
,
...
...
backend/internal/service/openai_gateway_service.go
View file @
d04b47b3
...
...
@@ -1741,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
resp
*
http
.
Response
,
c
*
gin
.
Context
,
)
(
*
OpenAIUsage
,
error
)
{
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
maxBytes
:=
resolveUpstreamResponseReadLimit
(
s
.
cfg
)
body
,
err
:=
readUpstreamResponseBodyLimited
(
resp
.
Body
,
maxBytes
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUpstreamResponseBodyTooLarge
)
{
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
"upstream response too large"
,
""
)
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream response too large"
,
},
})
}
return
nil
,
err
}
...
...
@@ -2371,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
}
func
(
s
*
OpenAIGatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
)
(
*
OpenAIUsage
,
error
)
{
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
maxBytes
:=
resolveUpstreamResponseReadLimit
(
s
.
cfg
)
body
,
err
:=
readUpstreamResponseBodyLimited
(
resp
.
Body
,
maxBytes
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUpstreamResponseBodyTooLarge
)
{
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
"upstream response too large"
,
""
)
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream response too large"
,
},
})
}
return
nil
,
err
}
...
...
@@ -2930,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
return
normalized
,
changed
,
nil
}
func
detectOpenAIPassthroughInstructionsRejectReason
(
reqModel
string
,
body
[]
byte
)
string
{
model
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
reqModel
))
if
!
strings
.
Contains
(
model
,
"codex"
)
{
return
""
}
instructions
:=
gjson
.
GetBytes
(
body
,
"instructions"
)
if
!
instructions
.
Exists
()
{
return
"instructions_missing"
}
if
instructions
.
Type
!=
gjson
.
String
{
return
"instructions_not_string"
}
if
strings
.
TrimSpace
(
instructions
.
String
())
==
""
{
return
"instructions_empty"
}
return
""
}
func
extractOpenAIReasoningEffortFromBody
(
body
[]
byte
,
requestedModel
string
)
*
string
{
reasoningEffort
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"reasoning.effort"
)
.
String
())
if
reasoningEffort
==
""
{
...
...
@@ -3002,22 +3041,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return
""
}
}
func
detectOpenAIPassthroughInstructionsRejectReason
(
reqModel
string
,
body
[]
byte
)
string
{
model
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
reqModel
))
if
!
strings
.
Contains
(
model
,
"codex"
)
{
return
""
}
instructions
:=
gjson
.
GetBytes
(
body
,
"instructions"
)
if
!
instructions
.
Exists
()
{
return
"instructions_missing"
}
if
instructions
.
Type
!=
gjson
.
String
{
return
"instructions_not_string"
}
if
strings
.
TrimSpace
(
instructions
.
String
())
==
""
{
return
"instructions_empty"
}
return
""
}
backend/internal/service/upstream_response_limit.go
0 → 100644
View file @
d04b47b3
package
service
import
(
"errors"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
)
var
ErrUpstreamResponseBodyTooLarge
=
errors
.
New
(
"upstream response body too large"
)
const
defaultUpstreamResponseReadMaxBytes
int64
=
8
*
1024
*
1024
func
resolveUpstreamResponseReadLimit
(
cfg
*
config
.
Config
)
int64
{
if
cfg
!=
nil
&&
cfg
.
Gateway
.
UpstreamResponseReadMaxBytes
>
0
{
return
cfg
.
Gateway
.
UpstreamResponseReadMaxBytes
}
return
defaultUpstreamResponseReadMaxBytes
}
func
readUpstreamResponseBodyLimited
(
reader
io
.
Reader
,
maxBytes
int64
)
([]
byte
,
error
)
{
if
reader
==
nil
{
return
nil
,
errors
.
New
(
"response body is nil"
)
}
if
maxBytes
<=
0
{
maxBytes
=
defaultUpstreamResponseReadMaxBytes
}
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
reader
,
maxBytes
+
1
))
if
err
!=
nil
{
return
nil
,
err
}
if
int64
(
len
(
body
))
>
maxBytes
{
return
nil
,
fmt
.
Errorf
(
"%w: limit=%d"
,
ErrUpstreamResponseBodyTooLarge
,
maxBytes
)
}
return
body
,
nil
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment