Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
07be258d
Unverified
Commit
07be258d
authored
Feb 24, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 24, 2026
Browse files
Merge pull request #603 from mt21625457/release
feat : 大幅度的性能优化 和 新增了很多功能
parents
dbdb2959
53d55bb9
Changes
271
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
271 of 271+
files are displayed.
Plain diff
Email patch
backend/internal/server/middleware/client_request_id.go
View file @
07be258d
...
...
@@ -2,10 +2,13 @@ package middleware
import
(
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
...
...
@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc {
}
id
:=
uuid
.
New
()
.
String
()
c
.
Request
=
c
.
Request
.
WithContext
(
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ClientRequestID
,
id
))
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ClientRequestID
,
id
)
requestLogger
:=
logger
.
FromContext
(
ctx
)
.
With
(
zap
.
String
(
"client_request_id"
,
strings
.
TrimSpace
(
id
)))
ctx
=
logger
.
IntoContext
(
ctx
,
requestLogger
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Next
()
}
}
backend/internal/server/middleware/cors.go
View file @
07be258d
...
...
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
}
allowedSet
[
origin
]
=
struct
{}{}
}
allowHeaders
:=
[]
string
{
"Content-Type"
,
"Content-Length"
,
"Accept-Encoding"
,
"X-CSRF-Token"
,
"Authorization"
,
"accept"
,
"origin"
,
"Cache-Control"
,
"X-Requested-With"
,
"X-API-Key"
,
}
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
openAIProperties
:=
[]
string
{
"lang"
,
"package-version"
,
"os"
,
"arch"
,
"retry-count"
,
"runtime"
,
"runtime-version"
,
"async"
,
"helper-method"
,
"poll-helper"
,
"custom-poll-interval"
,
"timeout"
,
}
for
_
,
prop
:=
range
openAIProperties
{
allowHeaders
=
append
(
allowHeaders
,
"x-stainless-"
+
prop
)
}
allowHeadersValue
:=
strings
.
Join
(
allowHeaders
,
", "
)
return
func
(
c
*
gin
.
Context
)
{
origin
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Origin"
))
...
...
@@ -68,19 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
if
allowCredentials
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
}
}
allowHeaders
:=
[]
string
{
"Content-Type"
,
"Content-Length"
,
"Accept-Encoding"
,
"X-CSRF-Token"
,
"Authorization"
,
"accept"
,
"origin"
,
"Cache-Control"
,
"X-Requested-With"
,
"X-API-Key"
}
// openai node sdk
openAIProperties
:=
[]
string
{
"lang"
,
"package-version"
,
"os"
,
"arch"
,
"retry-count"
,
"runtime"
,
"runtime-version"
,
"async"
,
"helper-method"
,
"poll-helper"
,
"custom-poll-interval"
,
"timeout"
}
for
_
,
prop
:=
range
openAIProperties
{
allowHeaders
=
append
(
allowHeaders
,
"x-stainless-"
+
prop
)
}
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
strings
.
Join
(
allowHeaders
,
", "
))
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
allowHeadersValue
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Expose-Headers"
,
"ETag"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Max-Age"
,
"86400"
)
}
// 处理预检请求
if
c
.
Request
.
Method
==
http
.
MethodOptions
{
if
originAllowed
{
...
...
backend/internal/server/middleware/cors_test.go
0 → 100644
View file @
07be258d
package
middleware
import
(
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func
init
()
{
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
gin
.
SetMode
(
gin
.
TestMode
)
}
// --- Task 8.2: 验证 CORS 条件化头部 ---
func
TestCORS_DisallowedOrigin_NoAllowHeaders
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
tests
:=
[]
struct
{
name
string
method
string
origin
string
}{
{
name
:
"preflight_disallowed_origin"
,
method
:
http
.
MethodOptions
,
origin
:
"https://evil.example.com"
,
},
{
name
:
"get_disallowed_origin"
,
method
:
http
.
MethodGet
,
origin
:
"https://evil.example.com"
,
},
{
name
:
"post_disallowed_origin"
,
method
:
http
.
MethodPost
,
origin
:
"https://attacker.example.com"
,
},
{
name
:
"preflight_no_origin"
,
method
:
http
.
MethodOptions
,
origin
:
""
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
tt
.
method
,
"/"
,
nil
)
if
tt
.
origin
!=
""
{
c
.
Request
.
Header
.
Set
(
"Origin"
,
tt
.
origin
)
}
middleware
(
c
)
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
),
"不允许的 origin 不应收到 Allow-Headers"
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Methods"
),
"不允许的 origin 不应收到 Allow-Methods"
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Max-Age"
),
"不允许的 origin 不应收到 Max-Age"
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
),
"不允许的 origin 不应收到 Allow-Origin"
)
})
}
}
func
TestCORS_AllowedOrigin_HasAllowHeaders
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
tests
:=
[]
struct
{
name
string
method
string
}{
{
name
:
"preflight_OPTIONS"
,
method
:
http
.
MethodOptions
},
{
name
:
"normal_GET"
,
method
:
http
.
MethodGet
},
{
name
:
"normal_POST"
,
method
:
http
.
MethodPost
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
tt
.
method
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
),
"允许的 origin 应收到 Allow-Headers"
)
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Methods"
),
"允许的 origin 应收到 Allow-Methods"
)
assert
.
Equal
(
t
,
"86400"
,
w
.
Header
()
.
Get
(
"Access-Control-Max-Age"
),
"允许的 origin 应收到 Max-Age=86400"
)
assert
.
Equal
(
t
,
"https://allowed.example.com"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
),
"允许的 origin 应收到 Allow-Origin"
)
})
}
}
func
TestCORS_PreflightDisallowedOrigin_ReturnsForbidden
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodOptions
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://evil.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
http
.
StatusForbidden
,
w
.
Code
,
"不允许的 origin 的 preflight 请求应返回 403"
)
}
func
TestCORS_PreflightAllowedOrigin_ReturnsNoContent
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodOptions
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
http
.
StatusNoContent
,
w
.
Code
,
"允许的 origin 的 preflight 请求应返回 204"
)
}
func
TestCORS_WildcardOrigin_AllowsAny
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"*"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://any-origin.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"*"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
),
"通配符配置应返回 *"
)
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
),
"通配符 origin 应设置 Allow-Headers"
)
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Methods"
),
"通配符 origin 应设置 Allow-Methods"
)
}
func
TestCORS_AllowCredentials_SetCorrectly
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
true
,
}
middleware
:=
CORS
(
cfg
)
t
.
Run
(
"allowed_origin_gets_credentials"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"true"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Credentials"
),
"允许的 origin 且开启 credentials 应设置 Allow-Credentials"
)
})
t
.
Run
(
"disallowed_origin_no_credentials"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://evil.example.com"
)
middleware
(
c
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Credentials"
),
"不允许的 origin 不应收到 Allow-Credentials"
)
})
}
func
TestCORS_WildcardWithCredentials_DisablesCredentials
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"*"
},
AllowCredentials
:
true
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://any.example.com"
)
middleware
(
c
)
// 通配符 + credentials 不兼容,credentials 应被禁用
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Credentials"
),
"通配符 origin 应禁用 Allow-Credentials"
)
}
func
TestCORS_MultipleAllowedOrigins
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://app1.example.com"
,
"https://app2.example.com"
,
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
t
.
Run
(
"first_origin_allowed"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://app1.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"https://app1.example.com"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
))
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
))
})
t
.
Run
(
"second_origin_allowed"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://app2.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"https://app2.example.com"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
))
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
))
})
t
.
Run
(
"unlisted_origin_rejected"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://app3.example.com"
)
middleware
(
c
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
))
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
))
})
}
func
TestCORS_VaryHeader_SetForSpecificOrigin
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
assert
.
Contains
(
t
,
w
.
Header
()
.
Values
(
"Vary"
),
"Origin"
,
"非通配符允许的 origin 应设置 Vary: Origin"
)
}
func
TestNormalizeOrigins
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
[]
string
expect
[]
string
}{
{
name
:
"nil_input"
,
input
:
nil
,
expect
:
nil
},
{
name
:
"empty_input"
,
input
:
[]
string
{},
expect
:
nil
},
{
name
:
"trims_whitespace"
,
input
:
[]
string
{
" https://a.com "
,
" https://b.com"
},
expect
:
[]
string
{
"https://a.com"
,
"https://b.com"
}},
{
name
:
"removes_empty_strings"
,
input
:
[]
string
{
""
,
" "
,
"https://a.com"
},
expect
:
[]
string
{
"https://a.com"
}},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
normalizeOrigins
(
tt
.
input
)
assert
.
Equal
(
t
,
tt
.
expect
,
result
)
})
}
}
backend/internal/server/middleware/jwt_auth.go
View file @
07be258d
...
...
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
// 验证Bearer scheme
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
!=
2
||
parts
[
0
]
!=
"Bearer"
{
if
len
(
parts
)
!=
2
||
!
strings
.
EqualFold
(
parts
[
0
]
,
"Bearer"
)
{
AbortWithError
(
c
,
401
,
"INVALID_AUTH_HEADER"
,
"Authorization header format must be 'Bearer {token}'"
)
return
}
tokenString
:=
parts
[
1
]
tokenString
:=
strings
.
TrimSpace
(
parts
[
1
]
)
if
tokenString
==
""
{
AbortWithError
(
c
,
401
,
"EMPTY_TOKEN"
,
"Token cannot be empty"
)
return
...
...
backend/internal/server/middleware/jwt_auth_test.go
0 → 100644
View file @
07be258d
//go:build unit
package
middleware
import
(
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。
type
stubJWTUserRepo
struct
{
service
.
UserRepository
users
map
[
int64
]
*
service
.
User
}
func
(
r
*
stubJWTUserRepo
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
u
,
ok
:=
r
.
users
[
id
]
if
!
ok
{
return
nil
,
errors
.
New
(
"user not found"
)
}
return
u
,
nil
}
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func
newJWTTestEnv
(
users
map
[
int64
]
*
service
.
User
)
(
*
gin
.
Engine
,
*
service
.
AuthService
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{}
cfg
.
JWT
.
Secret
=
"test-jwt-secret-32bytes-long!!!"
cfg
.
JWT
.
AccessTokenExpireMinutes
=
60
userRepo
:=
&
stubJWTUserRepo
{
users
:
users
}
authSvc
:=
service
.
NewAuthService
(
userRepo
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
mw
:=
NewJWTAuthMiddleware
(
authSvc
,
userSvc
)
r
:=
gin
.
New
()
r
.
Use
(
gin
.
HandlerFunc
(
mw
))
r
.
GET
(
"/protected"
,
func
(
c
*
gin
.
Context
)
{
subject
,
_
:=
GetAuthSubjectFromContext
(
c
)
role
,
_
:=
GetUserRoleFromContext
(
c
)
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"user_id"
:
subject
.
UserID
,
"role"
:
role
,
})
})
return
r
,
authSvc
}
func
TestJWTAuth_ValidToken
(
t
*
testing
.
T
)
{
user
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
Concurrency
:
5
,
TokenVersion
:
1
,
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
user
})
token
,
err
:=
authSvc
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
var
body
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
float64
(
1
),
body
[
"user_id"
])
require
.
Equal
(
t
,
"user"
,
body
[
"role"
])
}
func
TestJWTAuth_ValidToken_LowercaseBearer
(
t
*
testing
.
T
)
{
user
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
Concurrency
:
5
,
TokenVersion
:
1
,
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
user
})
token
,
err
:=
authSvc
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestJWTAuth_MissingAuthorizationHeader
(
t
*
testing
.
T
)
{
router
,
_
:=
newJWTTestEnv
(
nil
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"UNAUTHORIZED"
,
body
.
Code
)
}
func
TestJWTAuth_InvalidHeaderFormat
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
header
string
}{
{
"无Bearer前缀"
,
"Token abc123"
},
{
"缺少空格分隔"
,
"Bearerabc123"
},
{
"仅有单词"
,
"abc123"
},
}
router
,
_
:=
newJWTTestEnv
(
nil
)
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
tt
.
header
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"INVALID_AUTH_HEADER"
,
body
.
Code
)
})
}
}
func
TestJWTAuth_EmptyToken
(
t
*
testing
.
T
)
{
router
,
_
:=
newJWTTestEnv
(
nil
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"EMPTY_TOKEN"
,
body
.
Code
)
}
func
TestJWTAuth_TamperedToken
(
t
*
testing
.
T
)
{
router
,
_
:=
newJWTTestEnv
(
nil
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"INVALID_TOKEN"
,
body
.
Code
)
}
func
TestJWTAuth_UserNotFound
(
t
*
testing
.
T
)
{
// 使用 user ID=1 的 token,但 repo 中没有该用户
fakeUser
:=
&
service
.
User
{
ID
:
999
,
Email
:
"ghost@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
TokenVersion
:
1
,
}
// 创建环境时不注入此用户,这样 GetByID 会失败
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{})
token
,
err
:=
authSvc
.
GenerateToken
(
fakeUser
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"USER_NOT_FOUND"
,
body
.
Code
)
}
func
TestJWTAuth_UserInactive
(
t
*
testing
.
T
)
{
user
:=
&
service
.
User
{
ID
:
1
,
Email
:
"disabled@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusDisabled
,
TokenVersion
:
1
,
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
user
})
token
,
err
:=
authSvc
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"USER_INACTIVE"
,
body
.
Code
)
}
func
TestJWTAuth_TokenVersionMismatch
(
t
*
testing
.
T
)
{
// Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改)
userForToken
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
TokenVersion
:
1
,
}
userInDB
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
TokenVersion
:
2
,
// 密码修改后版本递增
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
userInDB
})
token
,
err
:=
authSvc
.
GenerateToken
(
userForToken
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"TOKEN_REVOKED"
,
body
.
Code
)
}
backend/internal/server/middleware/logger.go
View file @
07be258d
package
middleware
import
(
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Logger 请求日志中间件
...
...
@@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc {
// 开始时间
startTime
:=
time
.
Now
()
// 请求路径
path
:=
c
.
Request
.
URL
.
Path
// 处理请求
c
.
Next
()
// 结束时间
endTime
:=
time
.
Now
()
// 跳过健康检查等高频探针路径的日志
if
path
==
"/health"
||
path
==
"/setup/status"
{
return
}
// 执行时间
endTime
:=
time
.
Now
()
latency
:=
endTime
.
Sub
(
startTime
)
// 请求方法
method
:=
c
.
Request
.
Method
// 请求路径
path
:=
c
.
Request
.
URL
.
Path
// 状态码
statusCode
:=
c
.
Writer
.
Status
()
// 客户端IP
clientIP
:=
c
.
ClientIP
()
// 协议版本
protocol
:=
c
.
Request
.
Proto
accountID
,
hasAccountID
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
AccountID
)
.
(
int64
)
platform
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Platform
)
.
(
string
)
model
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Model
)
.
(
string
)
fields
:=
[]
zap
.
Field
{
zap
.
String
(
"component"
,
"http.access"
),
zap
.
Int
(
"status_code"
,
statusCode
),
zap
.
Int64
(
"latency_ms"
,
latency
.
Milliseconds
()),
zap
.
String
(
"client_ip"
,
clientIP
),
zap
.
String
(
"protocol"
,
protocol
),
zap
.
String
(
"method"
,
method
),
zap
.
String
(
"path"
,
path
),
}
if
hasAccountID
&&
accountID
>
0
{
fields
=
append
(
fields
,
zap
.
Int64
(
"account_id"
,
accountID
))
}
if
platform
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"platform"
,
platform
))
}
if
model
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"model"
,
model
))
}
l
:=
logger
.
FromContext
(
c
.
Request
.
Context
())
.
With
(
fields
...
)
l
.
Info
(
"http request completed"
,
zap
.
Time
(
"completed_at"
,
endTime
))
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log
.
Printf
(
"[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s"
,
endTime
.
Format
(
"2006/01/02 - 15:04:05"
),
statusCode
,
latency
,
clientIP
,
protocol
,
method
,
path
,
)
// 如果有错误,额外记录错误信息
if
len
(
c
.
Errors
)
>
0
{
l
og
.
Printf
(
"[GIN] E
rrors
: %v
"
,
c
.
Errors
.
String
())
l
.
Warn
(
"http request contains gin errors"
,
zap
.
String
(
"e
rrors"
,
c
.
Errors
.
String
())
)
}
}
}
backend/internal/server/middleware/misc_coverage_test.go
0 → 100644
View file @
07be258d
//go:build unit
package
middleware
import
(
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestClientRequestID_GeneratesWhenMissing
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ClientRequestID
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
v
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ClientRequestID
)
require
.
NotNil
(
t
,
v
)
id
,
ok
:=
v
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
NotEmpty
(
t
,
id
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestClientRequestID_PreservesExisting
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ClientRequestID
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
id
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ClientRequestID
)
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"keep"
,
id
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
ClientRequestID
,
"keep"
))
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestRequestBodyLimit_LimitsBody
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
RequestBodyLimit
(
4
))
r
.
POST
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
_
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
require
.
Error
(
t
,
err
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/t"
,
bytes
.
NewBufferString
(
"12345"
))
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestForcePlatform_SetsContextAndGinValue
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ForcePlatform
(
"anthropic"
))
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
require
.
True
(
t
,
HasForcePlatform
(
c
))
v
,
ok
:=
GetForcePlatformFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"anthropic"
,
v
)
ctxV
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ForcePlatform
)
require
.
Equal
(
t
,
"anthropic"
,
ctxV
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAuthSubjectHelpers_RoundTrip
(
t
*
testing
.
T
)
{
c
:=
&
gin
.
Context
{}
c
.
Set
(
string
(
ContextKeyUser
),
AuthSubject
{
UserID
:
1
,
Concurrency
:
2
})
c
.
Set
(
string
(
ContextKeyUserRole
),
"admin"
)
sub
,
ok
:=
GetAuthSubjectFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
1
),
sub
.
UserID
)
require
.
Equal
(
t
,
2
,
sub
.
Concurrency
)
role
,
ok
:=
GetUserRoleFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"admin"
,
role
)
}
func
TestAPIKeyAndSubscriptionFromContext
(
t
*
testing
.
T
)
{
c
:=
&
gin
.
Context
{}
key
:=
&
service
.
APIKey
{
ID
:
1
}
c
.
Set
(
string
(
ContextKeyAPIKey
),
key
)
gotKey
,
ok
:=
GetAPIKeyFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
1
),
gotKey
.
ID
)
sub
:=
&
service
.
UserSubscription
{
ID
:
2
}
c
.
Set
(
string
(
ContextKeySubscription
),
sub
)
gotSub
,
ok
:=
GetSubscriptionFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
2
),
gotSub
.
ID
)
}
backend/internal/server/middleware/recovery_test.go
View file @
07be258d
...
...
@@ -3,6 +3,7 @@
package
middleware
import
(
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
...
...
@@ -14,6 +15,34 @@ import (
"github.com/stretchr/testify/require"
)
func
TestRecovery_PanicLogContainsInfo
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
// 临时替换 DefaultErrorWriter 以捕获日志输出
var
buf
bytes
.
Buffer
originalWriter
:=
gin
.
DefaultErrorWriter
gin
.
DefaultErrorWriter
=
&
buf
t
.
Cleanup
(
func
()
{
gin
.
DefaultErrorWriter
=
originalWriter
})
r
:=
gin
.
New
()
r
.
Use
(
Recovery
())
r
.
GET
(
"/panic"
,
func
(
c
*
gin
.
Context
)
{
panic
(
"custom panic message for test"
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/panic"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
w
.
Code
)
logOutput
:=
buf
.
String
()
require
.
Contains
(
t
,
logOutput
,
"custom panic message for test"
,
"日志应包含 panic 信息"
)
require
.
Contains
(
t
,
logOutput
,
"recovery_test.go"
,
"日志应包含堆栈跟踪文件名"
)
}
func
TestRecovery
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/request_access_logger_test.go
0 → 100644
View file @
07be258d
package
middleware
import
(
"context"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
type
testLogSink
struct
{
mu
sync
.
Mutex
events
[]
*
logger
.
LogEvent
}
func
(
s
*
testLogSink
)
WriteLogEvent
(
event
*
logger
.
LogEvent
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
events
=
append
(
s
.
events
,
event
)
}
func
(
s
*
testLogSink
)
list
()
[]
*
logger
.
LogEvent
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
out
:=
make
([]
*
logger
.
LogEvent
,
len
(
s
.
events
))
copy
(
out
,
s
.
events
)
return
out
}
func
initMiddlewareTestLogger
(
t
*
testing
.
T
)
*
testLogSink
{
return
initMiddlewareTestLoggerWithLevel
(
t
,
"debug"
)
}
func
initMiddlewareTestLoggerWithLevel
(
t
*
testing
.
T
,
level
string
)
*
testLogSink
{
t
.
Helper
()
level
=
strings
.
TrimSpace
(
level
)
if
level
==
""
{
level
=
"debug"
}
if
err
:=
logger
.
Init
(
logger
.
InitOptions
{
Level
:
level
,
Format
:
"json"
,
ServiceName
:
"sub2api"
,
Environment
:
"test"
,
Output
:
logger
.
OutputOptions
{
ToStdout
:
false
,
ToFile
:
false
,
},
});
err
!=
nil
{
t
.
Fatalf
(
"init logger: %v"
,
err
)
}
sink
:=
&
testLogSink
{}
logger
.
SetSink
(
sink
)
t
.
Cleanup
(
func
()
{
logger
.
SetSink
(
nil
)
})
return
sink
}
func
TestRequestLogger_GenerateAndPropagateRequestID
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
RequestLogger
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
reqID
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
RequestID
)
.
(
string
)
if
!
ok
||
reqID
==
""
{
t
.
Fatalf
(
"request_id missing in context"
)
}
if
got
:=
c
.
Writer
.
Header
()
.
Get
(
requestIDHeader
);
got
!=
reqID
{
t
.
Fatalf
(
"response header request_id mismatch, header=%q ctx=%q"
,
got
,
reqID
)
}
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
if
w
.
Header
()
.
Get
(
requestIDHeader
)
==
""
{
t
.
Fatalf
(
"X-Request-ID should be set"
)
}
}
func
TestRequestLogger_KeepIncomingRequestID
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
RequestLogger
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
reqID
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
RequestID
)
.
(
string
)
if
reqID
!=
"rid-fixed"
{
t
.
Fatalf
(
"request_id=%q, want rid-fixed"
,
reqID
)
}
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
requestIDHeader
,
"rid-fixed"
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
if
got
:=
w
.
Header
()
.
Get
(
requestIDHeader
);
got
!=
"rid-fixed"
{
t
.
Fatalf
(
"header=%q, want rid-fixed"
,
got
)
}
}
func
TestLogger_AccessLogIncludesCoreFields
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
sink
:=
initMiddlewareTestLogger
(
t
)
r
:=
gin
.
New
()
r
.
Use
(
Logger
())
r
.
Use
(
func
(
c
*
gin
.
Context
)
{
ctx
:=
c
.
Request
.
Context
()
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
AccountID
,
int64
(
101
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Platform
,
"openai"
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Model
,
"gpt-5"
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Next
()
})
r
.
GET
(
"/api/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusCreated
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/test"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusCreated
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
events
:=
sink
.
list
()
if
len
(
events
)
==
0
{
t
.
Fatalf
(
"expected at least one log event"
)
}
found
:=
false
for
_
,
event
:=
range
events
{
if
event
==
nil
||
event
.
Message
!=
"http request completed"
{
continue
}
found
=
true
switch
v
:=
event
.
Fields
[
"status_code"
]
.
(
type
)
{
case
int
:
if
v
!=
http
.
StatusCreated
{
t
.
Fatalf
(
"status_code field mismatch: %v"
,
v
)
}
case
int64
:
if
v
!=
int64
(
http
.
StatusCreated
)
{
t
.
Fatalf
(
"status_code field mismatch: %v"
,
v
)
}
default
:
t
.
Fatalf
(
"status_code type mismatch: %T"
,
v
)
}
switch
v
:=
event
.
Fields
[
"account_id"
]
.
(
type
)
{
case
int64
:
if
v
!=
101
{
t
.
Fatalf
(
"account_id field mismatch: %v"
,
v
)
}
case
int
:
if
v
!=
101
{
t
.
Fatalf
(
"account_id field mismatch: %v"
,
v
)
}
default
:
t
.
Fatalf
(
"account_id type mismatch: %T"
,
v
)
}
if
event
.
Fields
[
"platform"
]
!=
"openai"
||
event
.
Fields
[
"model"
]
!=
"gpt-5"
{
t
.
Fatalf
(
"platform/model mismatch: %+v"
,
event
.
Fields
)
}
}
if
!
found
{
t
.
Fatalf
(
"access log event not found"
)
}
}
func
TestLogger_HealthPathSkipped
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
sink
:=
initMiddlewareTestLogger
(
t
)
r
:=
gin
.
New
()
r
.
Use
(
Logger
())
r
.
GET
(
"/health"
,
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/health"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
if
len
(
sink
.
list
())
!=
0
{
t
.
Fatalf
(
"health endpoint should not write access log"
)
}
}
func
TestLogger_AccessLogDroppedWhenLevelWarn
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
sink
:=
initMiddlewareTestLoggerWithLevel
(
t
,
"warn"
)
r
:=
gin
.
New
()
r
.
Use
(
RequestLogger
())
r
.
Use
(
Logger
())
r
.
GET
(
"/api/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusCreated
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/test"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusCreated
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
events
:=
sink
.
list
()
for
_
,
event
:=
range
events
{
if
event
!=
nil
&&
event
.
Message
==
"http request completed"
{
t
.
Fatalf
(
"access log should not be indexed when level=warn: %+v"
,
event
)
}
}
}
backend/internal/server/middleware/request_logger.go
0 → 100644
View file @
07be258d
package
middleware
import
(
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
const
requestIDHeader
=
"X-Request-ID"
// RequestLogger 在请求入口注入 request-scoped logger。
func
RequestLogger
()
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
if
c
.
Request
==
nil
{
c
.
Next
()
return
}
requestID
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
requestIDHeader
))
if
requestID
==
""
{
requestID
=
uuid
.
NewString
()
}
c
.
Header
(
requestIDHeader
,
requestID
)
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
RequestID
,
requestID
)
clientRequestID
,
_
:=
ctx
.
Value
(
ctxkey
.
ClientRequestID
)
.
(
string
)
requestLogger
:=
logger
.
With
(
zap
.
String
(
"component"
,
"http"
),
zap
.
String
(
"request_id"
,
requestID
),
zap
.
String
(
"client_request_id"
,
strings
.
TrimSpace
(
clientRequestID
)),
zap
.
String
(
"path"
,
c
.
Request
.
URL
.
Path
),
zap
.
String
(
"method"
,
c
.
Request
.
Method
),
)
ctx
=
logger
.
IntoContext
(
ctx
,
requestLogger
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Next
()
}
}
backend/internal/server/middleware/security_headers.go
View file @
07be258d
...
...
@@ -3,6 +3,8 @@ package middleware
import
(
"crypto/rand"
"encoding/base64"
"fmt"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -18,11 +20,14 @@ const (
CloudflareInsightsDomain
=
"https://static.cloudflareinsights.com"
)
// GenerateNonce generates a cryptographically secure random nonce
func
GenerateNonce
()
string
{
// GenerateNonce generates a cryptographically secure random nonce.
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
func
GenerateNonce
()
(
string
,
error
)
{
b
:=
make
([]
byte
,
16
)
_
,
_
=
rand
.
Read
(
b
)
return
base64
.
StdEncoding
.
EncodeToString
(
b
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate CSP nonce: %w"
,
err
)
}
return
base64
.
StdEncoding
.
EncodeToString
(
b
),
nil
}
// GetNonceFromContext retrieves the CSP nonce from gin context
...
...
@@ -52,13 +57,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
if
cfg
.
Enabled
{
// Generate nonce for this request
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
if
err
!=
nil
{
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
log
.
Printf
(
"[SecurityHeaders] %v — 降级为无 nonce 的 CSP"
,
err
)
finalPolicy
:=
strings
.
ReplaceAll
(
policy
,
NonceTemplate
,
"'unsafe-inline'"
)
c
.
Header
(
"Content-Security-Policy"
,
finalPolicy
)
}
else
{
c
.
Set
(
CSPNonceKey
,
nonce
)
// Replace nonce placeholder in policy
finalPolicy
:=
strings
.
ReplaceAll
(
policy
,
NonceTemplate
,
"'nonce-"
+
nonce
+
"'"
)
c
.
Header
(
"Content-Security-Policy"
,
finalPolicy
)
}
}
c
.
Next
()
}
}
...
...
backend/internal/server/middleware/security_headers_test.go
View file @
07be258d
...
...
@@ -19,7 +19,8 @@ func init() {
func
TestGenerateNonce
(
t
*
testing
.
T
)
{
t
.
Run
(
"generates_valid_base64_string"
,
func
(
t
*
testing
.
T
)
{
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
require
.
NoError
(
t
,
err
)
// Should be valid base64
decoded
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
nonce
)
...
...
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
t
.
Run
(
"generates_unique_nonces"
,
func
(
t
*
testing
.
T
)
{
nonces
:=
make
(
map
[
string
]
bool
)
for
i
:=
0
;
i
<
100
;
i
++
{
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
require
.
NoError
(
t
,
err
)
assert
.
False
(
t
,
nonces
[
nonce
],
"nonce should be unique"
)
nonces
[
nonce
]
=
true
}
})
t
.
Run
(
"nonce_has_expected_length"
,
func
(
t
*
testing
.
T
)
{
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
require
.
NoError
(
t
,
err
)
// 16 bytes -> 24 chars in base64 (with padding)
assert
.
Len
(
t
,
nonce
,
24
)
})
...
...
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
// Benchmark tests
func
BenchmarkGenerateNonce
(
b
*
testing
.
B
)
{
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
GenerateNonce
()
_
,
_
=
GenerateNonce
()
}
}
...
...
backend/internal/server/router.go
View file @
07be258d
...
...
@@ -29,6 +29,7 @@ func SetupRouter(
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
// 应用中间件
r
.
Use
(
middleware2
.
RequestLogger
())
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
CORS
(
cfg
.
CORS
))
r
.
Use
(
middleware2
.
SecurityHeaders
(
cfg
.
Security
.
CSP
))
...
...
backend/internal/server/routes/admin.go
View file @
07be258d
...
...
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes
(
admin
,
h
)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes
(
admin
,
h
)
// Gemini OAuth
registerGeminiOAuthRoutes
(
admin
,
h
)
...
...
@@ -101,6 +103,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
runtime
.
GET
(
"/alert"
,
h
.
Admin
.
Ops
.
GetAlertRuntimeSettings
)
runtime
.
PUT
(
"/alert"
,
h
.
Admin
.
Ops
.
UpdateAlertRuntimeSettings
)
runtime
.
GET
(
"/logging"
,
h
.
Admin
.
Ops
.
GetRuntimeLogConfig
)
runtime
.
PUT
(
"/logging"
,
h
.
Admin
.
Ops
.
UpdateRuntimeLogConfig
)
runtime
.
POST
(
"/logging/reset"
,
h
.
Admin
.
Ops
.
ResetRuntimeLogConfig
)
}
// Advanced settings (DB-backed)
...
...
@@ -144,12 +149,18 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Request drilldown (success + error)
ops
.
GET
(
"/requests"
,
h
.
Admin
.
Ops
.
ListRequestDetails
)
// Indexed system logs
ops
.
GET
(
"/system-logs"
,
h
.
Admin
.
Ops
.
ListSystemLogs
)
ops
.
POST
(
"/system-logs/cleanup"
,
h
.
Admin
.
Ops
.
CleanupSystemLogs
)
ops
.
GET
(
"/system-logs/health"
,
h
.
Admin
.
Ops
.
GetSystemLogIngestionHealth
)
// Dashboard (vNext - raw path for MVP)
ops
.
GET
(
"/dashboard/overview"
,
h
.
Admin
.
Ops
.
GetDashboardOverview
)
ops
.
GET
(
"/dashboard/throughput-trend"
,
h
.
Admin
.
Ops
.
GetDashboardThroughputTrend
)
ops
.
GET
(
"/dashboard/latency-histogram"
,
h
.
Admin
.
Ops
.
GetDashboardLatencyHistogram
)
ops
.
GET
(
"/dashboard/error-trend"
,
h
.
Admin
.
Ops
.
GetDashboardErrorTrend
)
ops
.
GET
(
"/dashboard/error-distribution"
,
h
.
Admin
.
Ops
.
GetDashboardErrorDistribution
)
ops
.
GET
(
"/dashboard/openai-token-stats"
,
h
.
Admin
.
Ops
.
GetDashboardOpenAITokenStats
)
}
}
...
...
@@ -267,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerSoraOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
sora
:=
admin
.
Group
(
"/sora"
)
{
sora
.
POST
(
"/generate-auth-url"
,
h
.
Admin
.
OpenAIOAuth
.
GenerateAuthURL
)
sora
.
POST
(
"/exchange-code"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeCode
)
sora
.
POST
(
"/refresh-token"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/st2at"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeSoraSessionToken
)
sora
.
POST
(
"/rt2at"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/accounts/:id/refresh"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshAccountToken
)
sora
.
POST
(
"/create-from-oauth"
,
h
.
Admin
.
OpenAIOAuth
.
CreateAccountFromOAuth
)
}
}
func
registerGeminiOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
gemini
:=
admin
.
Group
(
"/gemini"
)
{
...
...
@@ -297,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies
.
PUT
(
"/:id"
,
h
.
Admin
.
Proxy
.
Update
)
proxies
.
DELETE
(
"/:id"
,
h
.
Admin
.
Proxy
.
Delete
)
proxies
.
POST
(
"/:id/test"
,
h
.
Admin
.
Proxy
.
Test
)
proxies
.
POST
(
"/:id/quality-check"
,
h
.
Admin
.
Proxy
.
CheckQuality
)
proxies
.
GET
(
"/:id/stats"
,
h
.
Admin
.
Proxy
.
GetStats
)
proxies
.
GET
(
"/:id/accounts"
,
h
.
Admin
.
Proxy
.
GetProxyAccounts
)
proxies
.
POST
(
"/batch-delete"
,
h
.
Admin
.
Proxy
.
BatchDelete
)
...
...
backend/internal/server/routes/auth.go
View file @
07be258d
...
...
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
// 公开接口
auth
:=
v1
.
Group
(
"/auth"
)
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
auth
.
POST
(
"/register"
,
rateLimiter
.
LimitWithOptions
(
"auth-register"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
rateLimiter
.
LimitWithOptions
(
"auth-login"
,
20
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
rateLimiter
.
LimitWithOptions
(
"auth-login-2fa"
,
20
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
rateLimiter
.
LimitWithOptions
(
"auth-send-verify-code"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
SendVerifyCode
)
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
auth
.
POST
(
"/refresh"
,
rateLimiter
.
LimitWithOptions
(
"refresh-token"
,
30
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
...
...
backend/internal/server/routes/auth_rate_limit_integration_test.go
0 → 100644
View file @
07be258d
//go:build integration
package
routes
import
(
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis
"github.com/testcontainers/testcontainers-go/modules/redis"
)
const
authRouteRedisImageTag
=
"redis:8.4-alpine"
func
TestAuthRegisterRateLimitThresholdHitReturns429
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
rdb
:=
startAuthRouteRedis
(
t
,
ctx
)
router
:=
newAuthRoutesTestRouter
(
rdb
)
const
path
=
"/api/v1/auth/register"
for
i
:=
1
;
i
<=
6
;
i
++
{
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
path
,
strings
.
NewReader
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
RemoteAddr
=
"198.51.100.10:23456"
w
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
w
,
req
)
if
i
<=
5
{
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
,
"第 %d 次请求应先进入业务校验"
,
i
)
continue
}
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
w
.
Code
,
"第 6 次请求应命中限流"
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"rate limit exceeded"
)
}
}
func
startAuthRouteRedis
(
t
*
testing
.
T
,
ctx
context
.
Context
)
*
redis
.
Client
{
t
.
Helper
()
ensureAuthRouteDockerAvailable
(
t
)
redisContainer
,
err
:=
tcredis
.
Run
(
ctx
,
authRouteRedisImageTag
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
redisContainer
.
Terminate
(
ctx
)
})
redisHost
,
err
:=
redisContainer
.
Host
(
ctx
)
require
.
NoError
(
t
,
err
)
redisPort
,
err
:=
redisContainer
.
MappedPort
(
ctx
,
"6379/tcp"
)
require
.
NoError
(
t
,
err
)
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
fmt
.
Sprintf
(
"%s:%d"
,
redisHost
,
redisPort
.
Int
()),
DB
:
0
,
})
require
.
NoError
(
t
,
rdb
.
Ping
(
ctx
)
.
Err
())
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
return
rdb
}
func
ensureAuthRouteDockerAvailable
(
t
*
testing
.
T
)
{
t
.
Helper
()
if
authRouteDockerAvailable
()
{
return
}
t
.
Skip
(
"Docker 未启用,跳过认证限流集成测试"
)
}
func
authRouteDockerAvailable
()
bool
{
if
os
.
Getenv
(
"DOCKER_HOST"
)
!=
""
{
return
true
}
socketCandidates
:=
[]
string
{
"/var/run/docker.sock"
,
filepath
.
Join
(
os
.
Getenv
(
"XDG_RUNTIME_DIR"
),
"docker.sock"
),
filepath
.
Join
(
authRouteUserHomeDir
(),
".docker"
,
"run"
,
"docker.sock"
),
filepath
.
Join
(
authRouteUserHomeDir
(),
".docker"
,
"desktop"
,
"docker.sock"
),
filepath
.
Join
(
"/run/user"
,
strconv
.
Itoa
(
os
.
Getuid
()),
"docker.sock"
),
}
for
_
,
socket
:=
range
socketCandidates
{
if
socket
==
""
{
continue
}
if
_
,
err
:=
os
.
Stat
(
socket
);
err
==
nil
{
return
true
}
}
return
false
}
func
authRouteUserHomeDir
()
string
{
home
,
err
:=
os
.
UserHomeDir
()
if
err
!=
nil
{
return
""
}
return
home
}
backend/internal/server/routes/auth_rate_limit_test.go
0 → 100644
View file @
07be258d
package
routes
import
(
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func
newAuthRoutesTestRouter
(
redisClient
*
redis
.
Client
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
v1
:=
router
.
Group
(
"/api/v1"
)
RegisterAuthRoutes
(
v1
,
&
handler
.
Handlers
{
Auth
:
&
handler
.
AuthHandler
{},
Setting
:
&
handler
.
SettingHandler
{},
},
servermiddleware
.
JWTAuthMiddleware
(
func
(
c
*
gin
.
Context
)
{
c
.
Next
()
}),
redisClient
,
)
return
router
}
func
TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable
(
t
*
testing
.
T
)
{
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
"127.0.0.1:1"
,
DialTimeout
:
50
*
time
.
Millisecond
,
ReadTimeout
:
50
*
time
.
Millisecond
,
WriteTimeout
:
50
*
time
.
Millisecond
,
})
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
router
:=
newAuthRoutesTestRouter
(
rdb
)
paths
:=
[]
string
{
"/api/v1/auth/register"
,
"/api/v1/auth/login"
,
"/api/v1/auth/login/2fa"
,
"/api/v1/auth/send-verify-code"
,
}
for
_
,
path
:=
range
paths
{
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
path
,
strings
.
NewReader
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
RemoteAddr
=
"203.0.113.10:12345"
w
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
w
.
Code
,
"path=%s"
,
path
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"rate limit exceeded"
,
"path=%s"
,
path
)
}
}
backend/internal/server/routes/gateway.go
View file @
07be258d
package
routes
import
(
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
...
...
@@ -20,6 +22,11 @@ func RegisterGatewayRoutes(
cfg
*
config
.
Config
,
)
{
bodyLimit
:=
middleware
.
RequestBodyLimit
(
cfg
.
Gateway
.
MaxBodySize
)
soraMaxBodySize
:=
cfg
.
Gateway
.
SoraMaxBodySize
if
soraMaxBodySize
<=
0
{
soraMaxBodySize
=
cfg
.
Gateway
.
MaxBodySize
}
soraBodyLimit
:=
middleware
.
RequestBodyLimit
(
soraMaxBodySize
)
clientRequestID
:=
middleware
.
ClientRequestID
()
opsErrorLogger
:=
handler
.
OpsErrorLoggerMiddleware
(
opsService
)
...
...
@@ -36,6 +43,15 @@ func RegisterGatewayRoutes(
gateway
.
GET
(
"/usage"
,
h
.
Gateway
.
Usage
)
// OpenAI Responses API
gateway
.
POST
(
"/responses"
,
h
.
OpenAIGateway
.
Responses
)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway
.
POST
(
"/chat/completions"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"invalid_request_error"
,
"message"
:
"Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses."
,
},
})
})
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
...
...
@@ -82,4 +98,25 @@ func RegisterGatewayRoutes(
antigravityV1Beta
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
antigravityV1Beta
.
POST
(
"/models/*modelAction"
,
h
.
Gateway
.
GeminiV1BetaModels
)
}
// Sora 专用路由(强制使用 sora 平台)
soraV1
:=
r
.
Group
(
"/sora/v1"
)
soraV1
.
Use
(
soraBodyLimit
)
soraV1
.
Use
(
clientRequestID
)
soraV1
.
Use
(
opsErrorLogger
)
soraV1
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformSora
))
soraV1
.
Use
(
gin
.
HandlerFunc
(
apiKeyAuth
))
{
soraV1
.
POST
(
"/chat/completions"
,
h
.
SoraGateway
.
ChatCompletions
)
soraV1
.
GET
(
"/models"
,
h
.
Gateway
.
Models
)
}
// Sora 媒体代理(可选 API Key 验证)
if
cfg
.
Gateway
.
SoraMediaRequireAPIKey
{
r
.
GET
(
"/sora/media/*filepath"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
SoraGateway
.
MediaProxy
)
}
else
{
r
.
GET
(
"/sora/media/*filepath"
,
h
.
SoraGateway
.
MediaProxy
)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r
.
GET
(
"/sora/media-signed/*filepath"
,
h
.
SoraGateway
.
MediaProxySigned
)
}
backend/internal/service/account.go
View file @
07be258d
...
...
@@ -696,6 +696,51 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
return
false
}
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
//
// 新字段:accounts.extra.openai_passthrough。
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func
(
a
*
Account
)
IsOpenAIPassthroughEnabled
()
bool
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
||
a
.
Extra
==
nil
{
return
false
}
if
enabled
,
ok
:=
a
.
Extra
[
"openai_passthrough"
]
.
(
bool
);
ok
{
return
enabled
}
if
enabled
,
ok
:=
a
.
Extra
[
"openai_oauth_passthrough"
]
.
(
bool
);
ok
{
return
enabled
}
return
false
}
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func
(
a
*
Account
)
IsOpenAIOAuthPassthroughEnabled
()
bool
{
return
a
!=
nil
&&
a
.
IsOpenAIOAuth
()
&&
a
.
IsOpenAIPassthroughEnabled
()
}
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
// 字段:accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func
(
a
*
Account
)
IsAnthropicAPIKeyPassthroughEnabled
()
bool
{
if
a
==
nil
||
a
.
Platform
!=
PlatformAnthropic
||
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Extra
==
nil
{
return
false
}
enabled
,
ok
:=
a
.
Extra
[
"anthropic_passthrough"
]
.
(
bool
)
return
ok
&&
enabled
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
// 字段:accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func
(
a
*
Account
)
IsCodexCLIOnlyEnabled
()
bool
{
if
a
==
nil
||
!
a
.
IsOpenAIOAuth
()
||
a
.
Extra
==
nil
{
return
false
}
enabled
,
ok
:=
a
.
Extra
[
"codex_cli_only"
]
.
(
bool
)
return
ok
&&
enabled
}
// WindowCostSchedulability 窗口费用调度状态
type
WindowCostSchedulability
int
...
...
backend/internal/service/account_anthropic_passthrough_test.go
0 → 100644
View file @
07be258d
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestAccount_IsAnthropicAPIKeyPassthroughEnabled
(
t
*
testing
.
T
)
{
t
.
Run
(
"Anthropic API Key 开启"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
}
require
.
True
(
t
,
account
.
IsAnthropicAPIKeyPassthroughEnabled
())
})
t
.
Run
(
"Anthropic API Key 关闭"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
false
,
},
}
require
.
False
(
t
,
account
.
IsAnthropicAPIKeyPassthroughEnabled
())
})
t
.
Run
(
"字段类型非法默认关闭"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
"true"
,
},
}
require
.
False
(
t
,
account
.
IsAnthropicAPIKeyPassthroughEnabled
())
})
t
.
Run
(
"非 Anthropic API Key 账号始终关闭"
,
func
(
t
*
testing
.
T
)
{
oauth
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
}
require
.
False
(
t
,
oauth
.
IsAnthropicAPIKeyPassthroughEnabled
())
openai
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
}
require
.
False
(
t
,
openai
.
IsAnthropicAPIKeyPassthroughEnabled
())
})
}
Prev
1
…
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment