Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
9c02ab78
Commit
9c02ab78
authored
Jan 12, 2026
by
yangjianbo
Browse files
fix(rate_limiter): 更新速率限制逻辑,支持返回修复状态
parent
202ec21b
Changes
2
Hide whitespace changes
Inline
Side-by-side
backend/internal/middleware/rate_limiter.go
View file @
9c02ab78
...
@@ -2,7 +2,10 @@ package middleware
...
@@ -2,7 +2,10 @@ package middleware
import
(
import
(
"context"
"context"
"fmt"
"log"
"net/http"
"net/http"
"strconv"
"time"
"time"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
...
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
...
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
var
rateLimitScript
=
redis
.
NewScript
(
`
var
rateLimitScript
=
redis
.
NewScript
(
`
local current = redis.call('INCR', KEYS[1])
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
if current == 1 or ttl == -1 then
local repaired = 0
if current == 1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
redis.call('PEXPIRE', KEYS[1], ARGV[1])
elseif ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
repaired = 1
end
end
return current
return
{
current
, repaired}
`
)
`
)
// rateLimitRun 允许测试覆写脚本执行逻辑
// rateLimitRun 允许测试覆写脚本执行逻辑
var
rateLimitRun
=
func
(
ctx
context
.
Context
,
client
*
redis
.
Client
,
key
string
,
windowMillis
int64
)
(
int64
,
error
)
{
var
rateLimitRun
=
func
(
ctx
context
.
Context
,
client
*
redis
.
Client
,
key
string
,
windowMillis
int64
)
(
int64
,
bool
,
error
)
{
return
rateLimitScript
.
Run
(
ctx
,
client
,
[]
string
{
key
},
windowMillis
)
.
Int64
()
values
,
err
:=
rateLimitScript
.
Run
(
ctx
,
client
,
[]
string
{
key
},
windowMillis
)
.
Slice
()
if
err
!=
nil
{
return
0
,
false
,
err
}
if
len
(
values
)
<
2
{
return
0
,
false
,
fmt
.
Errorf
(
"rate limit script returned %d values"
,
len
(
values
))
}
count
,
err
:=
parseInt64
(
values
[
0
])
if
err
!=
nil
{
return
0
,
false
,
err
}
repaired
,
err
:=
parseInt64
(
values
[
1
])
if
err
!=
nil
{
return
0
,
false
,
err
}
return
count
,
repaired
==
1
,
nil
}
}
// RateLimiter Redis 速率限制器
// RateLimiter Redis 速率限制器
...
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
...
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
windowMillis
:=
windowTTLMillis
(
window
)
windowMillis
:=
windowTTLMillis
(
window
)
// 使用 Lua 脚本原子操作增加计数并设置过期
// 使用 Lua 脚本原子操作增加计数并设置过期
count
,
err
:=
rateLimitRun
(
ctx
,
r
.
redis
,
redisKey
,
windowMillis
)
count
,
repaired
,
err
:=
rateLimitRun
(
ctx
,
r
.
redis
,
redisKey
,
windowMillis
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[RateLimit] redis error: key=%s mode=%s err=%v"
,
redisKey
,
failureModeLabel
(
failureMode
),
err
)
if
failureMode
==
RateLimitFailClose
{
if
failureMode
==
RateLimitFailClose
{
abortRateLimit
(
c
)
abortRateLimit
(
c
)
return
return
...
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
...
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
c
.
Next
()
c
.
Next
()
return
return
}
}
if
repaired
{
log
.
Printf
(
"[RateLimit] ttl repaired: key=%s window_ms=%d"
,
redisKey
,
windowMillis
)
}
// 超过限制
// 超过限制
if
count
>
int64
(
limit
)
{
if
count
>
int64
(
limit
)
{
...
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
...
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
"message"
:
"Too many requests, please try again later"
,
"message"
:
"Too many requests, please try again later"
,
})
})
}
}
func
failureModeLabel
(
mode
RateLimitFailureMode
)
string
{
if
mode
==
RateLimitFailClose
{
return
"fail-close"
}
return
"fail-open"
}
func
parseInt64
(
value
any
)
(
int64
,
error
)
{
switch
v
:=
value
.
(
type
)
{
case
int64
:
return
v
,
nil
case
int
:
return
int64
(
v
),
nil
case
string
:
parsed
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
{
return
0
,
err
}
return
parsed
,
nil
default
:
return
0
,
fmt
.
Errorf
(
"unexpected value type %T"
,
value
)
}
}
backend/internal/middleware/rate_limiter_test.go
View file @
9c02ab78
...
@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
...
@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
originalRun
:=
rateLimitRun
originalRun
:=
rateLimitRun
counts
:=
[]
int64
{
1
,
2
}
counts
:=
[]
int64
{
1
,
2
}
callIndex
:=
0
callIndex
:=
0
rateLimitRun
=
func
(
ctx
context
.
Context
,
client
*
redis
.
Client
,
key
string
,
windowMillis
int64
)
(
int64
,
error
)
{
rateLimitRun
=
func
(
ctx
context
.
Context
,
client
*
redis
.
Client
,
key
string
,
windowMillis
int64
)
(
int64
,
bool
,
error
)
{
if
callIndex
>=
len
(
counts
)
{
if
callIndex
>=
len
(
counts
)
{
return
counts
[
len
(
counts
)
-
1
],
nil
return
counts
[
len
(
counts
)
-
1
],
false
,
nil
}
}
value
:=
counts
[
callIndex
]
value
:=
counts
[
callIndex
]
callIndex
++
callIndex
++
return
value
,
nil
return
value
,
false
,
nil
}
}
t
.
Cleanup
(
func
()
{
t
.
Cleanup
(
func
()
{
rateLimitRun
=
originalRun
rateLimitRun
=
originalRun
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment