Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6bccb8a8
Unverified
Commit
6bccb8a8
authored
Feb 24, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 24, 2026
Browse files
Merge branch 'main' into feature/antigravity-user-agent-configurable
parents
1fc6ef3d
3de1e0e4
Changes
270
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
270 of 270+
files are displayed.
Plain diff
Email patch
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
6bccb8a8
...
@@ -7,6 +7,7 @@ import (
...
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http"
"net/http/httptest"
"net/http/httptest"
"testing"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
...
@@ -18,7 +19,8 @@ import (
...
@@ -18,7 +19,8 @@ import (
)
)
type
fakeAPIKeyRepo
struct
{
type
fakeAPIKeyRepo
struct
{
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
updateLastUsed
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
}
}
func
(
f
fakeAPIKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
func
(
f
fakeAPIKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
...
@@ -78,6 +80,12 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([
...
@@ -78,6 +80,12 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([
func
(
f
fakeAPIKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
func
(
f
fakeAPIKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
return
0
,
errors
.
New
(
"not implemented"
)
}
}
func
(
f
fakeAPIKeyRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
if
f
.
updateLastUsed
!=
nil
{
return
f
.
updateLastUsed
(
ctx
,
id
,
usedAt
)
}
return
nil
}
type
googleErrorResponse
struct
{
type
googleErrorResponse
struct
{
Error
struct
{
Error
struct
{
...
@@ -356,3 +364,144 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
...
@@ -356,3 +364,144 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
require
.
Equal
(
t
,
"Insufficient account balance"
,
resp
.
Error
.
Message
)
require
.
Equal
(
t
,
"Insufficient account balance"
,
resp
.
Error
.
Message
)
require
.
Equal
(
t
,
"PERMISSION_DENIED"
,
resp
.
Error
.
Status
)
require
.
Equal
(
t
,
"PERMISSION_DENIED"
,
resp
.
Error
.
Status
)
}
}
func
TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
11
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
201
,
UserID
:
user
.
ID
,
Key
:
"google-touch-ok"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
var
touchedID
int64
var
touchedAt
time
.
Time
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchedID
=
id
touchedAt
=
usedAt
return
nil
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-goog-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
apiKey
.
ID
,
touchedID
)
require
.
False
(
t
,
touchedAt
.
IsZero
())
}
func
TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
12
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
202
,
UserID
:
user
.
ID
,
Key
:
"google-touch-fail"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
errors
.
New
(
"write failed"
)
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-goog-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
13
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
203
,
UserID
:
user
.
ID
,
Key
:
"google-touch-standard"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
nil
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
backend/internal/server/middleware/api_key_auth_test.go
View file @
6bccb8a8
...
@@ -57,10 +57,61 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
...
@@ -57,10 +57,61 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
},
},
}
}
t
.
Run
(
"standard_mode_needs_maintenance_does_not_block_request"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
cfg
.
SubscriptionMaintenance
.
WorkerCount
=
1
cfg
.
SubscriptionMaintenance
.
QueueSize
=
1
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
past
:=
time
.
Now
()
.
Add
(
-
48
*
time
.
Hour
)
sub
:=
&
service
.
UserSubscription
{
ID
:
55
,
UserID
:
user
.
ID
,
GroupID
:
group
.
ID
,
Status
:
service
.
SubscriptionStatusActive
,
ExpiresAt
:
time
.
Now
()
.
Add
(
24
*
time
.
Hour
),
DailyWindowStart
:
&
past
,
DailyUsageUSD
:
0
,
}
maintenanceCalled
:=
make
(
chan
struct
{},
1
)
subscriptionRepo
:=
&
stubUserSubscriptionRepo
{
getActive
:
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
clone
:=
*
sub
return
&
clone
,
nil
},
updateStatus
:
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
{
return
nil
},
activateWindow
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetDaily
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
maintenanceCalled
<-
struct
{}{}
return
nil
},
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
,
nil
,
cfg
)
t
.
Cleanup
(
subscriptionService
.
Stop
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
select
{
case
<-
maintenanceCalled
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatalf
(
"expected maintenance to be scheduled"
)
}
})
t
.
Run
(
"simple_mode_bypasses_quota_check"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"simple_mode_bypasses_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
w
:=
httptest
.
NewRecorder
()
...
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
...
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
})
t
.
Run
(
"simple_mode_accepts_lowercase_bearer"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"bearer "
+
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"standard_mode_enforces_quota_check"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"standard_mode_enforces_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
...
@@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
...
@@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
}
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
w
:=
httptest
.
NewRecorder
()
...
@@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
...
@@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
}
func
TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
IPWhitelist
:
[]
string
{
"1.2.3.4"
},
}
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
require
.
NoError
(
t
,
router
.
SetTrustedProxies
(
nil
))
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
RemoteAddr
=
"9.9.9.9:12345"
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
.
Header
.
Set
(
"X-Forwarded-For"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"X-Real-IP"
,
"1.2.3.4"
)
req
.
Header
.
Set
(
"CF-Connecting-IP"
,
"1.2.3.4"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"ACCESS_DENIED"
)
}
func
TestAPIKeyAuthTouchesLastUsedOnSuccess
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"touch-ok"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
var
touchedID
int64
var
touchedAt
time
.
Time
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchedID
=
id
touchedAt
=
usedAt
return
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
nil
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
require
.
Equal
(
t
,
apiKey
.
ID
,
touchedID
)
require
.
False
(
t
,
touchedAt
.
IsZero
(),
"expected touch timestamp"
)
}
func
TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
8
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
101
,
UserID
:
user
.
ID
,
Key
:
"touch-fail"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
errors
.
New
(
"db unavailable"
)
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
nil
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"touch failure should not block request"
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
func
TestAPIKeyAuthTouchesLastUsedInStandardMode
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
user
:=
&
service
.
User
{
ID
:
9
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
102
,
UserID
:
user
.
ID
,
Key
:
"touch-standard"
,
Status
:
service
.
StatusActive
,
User
:
user
,
}
touchCalls
:=
0
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
touchCalls
++
return
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
nil
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
@@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService
...
@@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService
}
}
type
stubApiKeyRepo
struct
{
type
stubApiKeyRepo
struct
{
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
updateLastUsed
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
}
}
func
(
r
*
stubApiKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
func
(
r
*
stubApiKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
...
@@ -323,6 +581,13 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
...
@@ -323,6 +581,13 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
return
0
,
errors
.
New
(
"not implemented"
)
return
0
,
errors
.
New
(
"not implemented"
)
}
}
func
(
r
*
stubApiKeyRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
if
r
.
updateLastUsed
!=
nil
{
return
r
.
updateLastUsed
(
ctx
,
id
,
usedAt
)
}
return
nil
}
type
stubUserSubscriptionRepo
struct
{
type
stubUserSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
...
...
backend/internal/server/middleware/client_request_id.go
View file @
6bccb8a8
...
@@ -2,10 +2,13 @@ package middleware
...
@@ -2,10 +2,13 @@ package middleware
import
(
import
(
"context"
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/google/uuid"
"go.uber.org/zap"
)
)
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
...
@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc {
...
@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc {
}
}
id
:=
uuid
.
New
()
.
String
()
id
:=
uuid
.
New
()
.
String
()
c
.
Request
=
c
.
Request
.
WithContext
(
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ClientRequestID
,
id
))
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ClientRequestID
,
id
)
requestLogger
:=
logger
.
FromContext
(
ctx
)
.
With
(
zap
.
String
(
"client_request_id"
,
strings
.
TrimSpace
(
id
)))
ctx
=
logger
.
IntoContext
(
ctx
,
requestLogger
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Next
()
c
.
Next
()
}
}
}
}
backend/internal/server/middleware/cors.go
View file @
6bccb8a8
...
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
...
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
}
}
allowedSet
[
origin
]
=
struct
{}{}
allowedSet
[
origin
]
=
struct
{}{}
}
}
allowHeaders
:=
[]
string
{
"Content-Type"
,
"Content-Length"
,
"Accept-Encoding"
,
"X-CSRF-Token"
,
"Authorization"
,
"accept"
,
"origin"
,
"Cache-Control"
,
"X-Requested-With"
,
"X-API-Key"
,
}
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
openAIProperties
:=
[]
string
{
"lang"
,
"package-version"
,
"os"
,
"arch"
,
"retry-count"
,
"runtime"
,
"runtime-version"
,
"async"
,
"helper-method"
,
"poll-helper"
,
"custom-poll-interval"
,
"timeout"
,
}
for
_
,
prop
:=
range
openAIProperties
{
allowHeaders
=
append
(
allowHeaders
,
"x-stainless-"
+
prop
)
}
allowHeadersValue
:=
strings
.
Join
(
allowHeaders
,
", "
)
return
func
(
c
*
gin
.
Context
)
{
return
func
(
c
*
gin
.
Context
)
{
origin
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Origin"
))
origin
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Origin"
))
...
@@ -68,19 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
...
@@ -68,19 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
if
allowCredentials
{
if
allowCredentials
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
}
}
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
allowHeadersValue
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Expose-Headers"
,
"ETag"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Max-Age"
,
"86400"
)
}
}
allowHeaders
:=
[]
string
{
"Content-Type"
,
"Content-Length"
,
"Accept-Encoding"
,
"X-CSRF-Token"
,
"Authorization"
,
"accept"
,
"origin"
,
"Cache-Control"
,
"X-Requested-With"
,
"X-API-Key"
}
// openai node sdk
openAIProperties
:=
[]
string
{
"lang"
,
"package-version"
,
"os"
,
"arch"
,
"retry-count"
,
"runtime"
,
"runtime-version"
,
"async"
,
"helper-method"
,
"poll-helper"
,
"custom-poll-interval"
,
"timeout"
}
for
_
,
prop
:=
range
openAIProperties
{
allowHeaders
=
append
(
allowHeaders
,
"x-stainless-"
+
prop
)
}
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
strings
.
Join
(
allowHeaders
,
", "
))
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
// 处理预检请求
// 处理预检请求
if
c
.
Request
.
Method
==
http
.
MethodOptions
{
if
c
.
Request
.
Method
==
http
.
MethodOptions
{
if
originAllowed
{
if
originAllowed
{
...
...
backend/internal/server/middleware/cors_test.go
0 → 100644
View file @
6bccb8a8
package
middleware
import
(
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func
init
()
{
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
gin
.
SetMode
(
gin
.
TestMode
)
}
// --- Task 8.2: 验证 CORS 条件化头部 ---
func
TestCORS_DisallowedOrigin_NoAllowHeaders
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
tests
:=
[]
struct
{
name
string
method
string
origin
string
}{
{
name
:
"preflight_disallowed_origin"
,
method
:
http
.
MethodOptions
,
origin
:
"https://evil.example.com"
,
},
{
name
:
"get_disallowed_origin"
,
method
:
http
.
MethodGet
,
origin
:
"https://evil.example.com"
,
},
{
name
:
"post_disallowed_origin"
,
method
:
http
.
MethodPost
,
origin
:
"https://attacker.example.com"
,
},
{
name
:
"preflight_no_origin"
,
method
:
http
.
MethodOptions
,
origin
:
""
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
tt
.
method
,
"/"
,
nil
)
if
tt
.
origin
!=
""
{
c
.
Request
.
Header
.
Set
(
"Origin"
,
tt
.
origin
)
}
middleware
(
c
)
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
),
"不允许的 origin 不应收到 Allow-Headers"
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Methods"
),
"不允许的 origin 不应收到 Allow-Methods"
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Max-Age"
),
"不允许的 origin 不应收到 Max-Age"
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
),
"不允许的 origin 不应收到 Allow-Origin"
)
})
}
}
func
TestCORS_AllowedOrigin_HasAllowHeaders
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
tests
:=
[]
struct
{
name
string
method
string
}{
{
name
:
"preflight_OPTIONS"
,
method
:
http
.
MethodOptions
},
{
name
:
"normal_GET"
,
method
:
http
.
MethodGet
},
{
name
:
"normal_POST"
,
method
:
http
.
MethodPost
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
tt
.
method
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
),
"允许的 origin 应收到 Allow-Headers"
)
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Methods"
),
"允许的 origin 应收到 Allow-Methods"
)
assert
.
Equal
(
t
,
"86400"
,
w
.
Header
()
.
Get
(
"Access-Control-Max-Age"
),
"允许的 origin 应收到 Max-Age=86400"
)
assert
.
Equal
(
t
,
"https://allowed.example.com"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
),
"允许的 origin 应收到 Allow-Origin"
)
})
}
}
func
TestCORS_PreflightDisallowedOrigin_ReturnsForbidden
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodOptions
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://evil.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
http
.
StatusForbidden
,
w
.
Code
,
"不允许的 origin 的 preflight 请求应返回 403"
)
}
func
TestCORS_PreflightAllowedOrigin_ReturnsNoContent
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodOptions
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
http
.
StatusNoContent
,
w
.
Code
,
"允许的 origin 的 preflight 请求应返回 204"
)
}
func
TestCORS_WildcardOrigin_AllowsAny
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"*"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://any-origin.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"*"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
),
"通配符配置应返回 *"
)
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
),
"通配符 origin 应设置 Allow-Headers"
)
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Methods"
),
"通配符 origin 应设置 Allow-Methods"
)
}
func
TestCORS_AllowCredentials_SetCorrectly
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
true
,
}
middleware
:=
CORS
(
cfg
)
t
.
Run
(
"allowed_origin_gets_credentials"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"true"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Credentials"
),
"允许的 origin 且开启 credentials 应设置 Allow-Credentials"
)
})
t
.
Run
(
"disallowed_origin_no_credentials"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://evil.example.com"
)
middleware
(
c
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Credentials"
),
"不允许的 origin 不应收到 Allow-Credentials"
)
})
}
func
TestCORS_WildcardWithCredentials_DisablesCredentials
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"*"
},
AllowCredentials
:
true
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://any.example.com"
)
middleware
(
c
)
// 通配符 + credentials 不兼容,credentials 应被禁用
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Credentials"
),
"通配符 origin 应禁用 Allow-Credentials"
)
}
func
TestCORS_MultipleAllowedOrigins
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://app1.example.com"
,
"https://app2.example.com"
,
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
t
.
Run
(
"first_origin_allowed"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://app1.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"https://app1.example.com"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
))
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
))
})
t
.
Run
(
"second_origin_allowed"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://app2.example.com"
)
middleware
(
c
)
assert
.
Equal
(
t
,
"https://app2.example.com"
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
))
assert
.
NotEmpty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
))
})
t
.
Run
(
"unlisted_origin_rejected"
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://app3.example.com"
)
middleware
(
c
)
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Origin"
))
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Access-Control-Allow-Headers"
))
})
}
func
TestCORS_VaryHeader_SetForSpecificOrigin
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CORSConfig
{
AllowedOrigins
:
[]
string
{
"https://allowed.example.com"
},
AllowCredentials
:
false
,
}
middleware
:=
CORS
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Origin"
,
"https://allowed.example.com"
)
middleware
(
c
)
assert
.
Contains
(
t
,
w
.
Header
()
.
Values
(
"Vary"
),
"Origin"
,
"非通配符允许的 origin 应设置 Vary: Origin"
)
}
func
TestNormalizeOrigins
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
[]
string
expect
[]
string
}{
{
name
:
"nil_input"
,
input
:
nil
,
expect
:
nil
},
{
name
:
"empty_input"
,
input
:
[]
string
{},
expect
:
nil
},
{
name
:
"trims_whitespace"
,
input
:
[]
string
{
" https://a.com "
,
" https://b.com"
},
expect
:
[]
string
{
"https://a.com"
,
"https://b.com"
}},
{
name
:
"removes_empty_strings"
,
input
:
[]
string
{
""
,
" "
,
"https://a.com"
},
expect
:
[]
string
{
"https://a.com"
}},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
normalizeOrigins
(
tt
.
input
)
assert
.
Equal
(
t
,
tt
.
expect
,
result
)
})
}
}
backend/internal/server/middleware/jwt_auth.go
View file @
6bccb8a8
...
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
...
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
// 验证Bearer scheme
// 验证Bearer scheme
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
!=
2
||
parts
[
0
]
!=
"Bearer"
{
if
len
(
parts
)
!=
2
||
!
strings
.
EqualFold
(
parts
[
0
]
,
"Bearer"
)
{
AbortWithError
(
c
,
401
,
"INVALID_AUTH_HEADER"
,
"Authorization header format must be 'Bearer {token}'"
)
AbortWithError
(
c
,
401
,
"INVALID_AUTH_HEADER"
,
"Authorization header format must be 'Bearer {token}'"
)
return
return
}
}
tokenString
:=
parts
[
1
]
tokenString
:=
strings
.
TrimSpace
(
parts
[
1
]
)
if
tokenString
==
""
{
if
tokenString
==
""
{
AbortWithError
(
c
,
401
,
"EMPTY_TOKEN"
,
"Token cannot be empty"
)
AbortWithError
(
c
,
401
,
"EMPTY_TOKEN"
,
"Token cannot be empty"
)
return
return
...
...
backend/internal/server/middleware/jwt_auth_test.go
0 → 100644
View file @
6bccb8a8
//go:build unit
package
middleware
import
(
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。
type
stubJWTUserRepo
struct
{
service
.
UserRepository
users
map
[
int64
]
*
service
.
User
}
func
(
r
*
stubJWTUserRepo
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
u
,
ok
:=
r
.
users
[
id
]
if
!
ok
{
return
nil
,
errors
.
New
(
"user not found"
)
}
return
u
,
nil
}
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func
newJWTTestEnv
(
users
map
[
int64
]
*
service
.
User
)
(
*
gin
.
Engine
,
*
service
.
AuthService
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{}
cfg
.
JWT
.
Secret
=
"test-jwt-secret-32bytes-long!!!"
cfg
.
JWT
.
AccessTokenExpireMinutes
=
60
userRepo
:=
&
stubJWTUserRepo
{
users
:
users
}
authSvc
:=
service
.
NewAuthService
(
userRepo
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
mw
:=
NewJWTAuthMiddleware
(
authSvc
,
userSvc
)
r
:=
gin
.
New
()
r
.
Use
(
gin
.
HandlerFunc
(
mw
))
r
.
GET
(
"/protected"
,
func
(
c
*
gin
.
Context
)
{
subject
,
_
:=
GetAuthSubjectFromContext
(
c
)
role
,
_
:=
GetUserRoleFromContext
(
c
)
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"user_id"
:
subject
.
UserID
,
"role"
:
role
,
})
})
return
r
,
authSvc
}
func
TestJWTAuth_ValidToken
(
t
*
testing
.
T
)
{
user
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
Concurrency
:
5
,
TokenVersion
:
1
,
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
user
})
token
,
err
:=
authSvc
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
var
body
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
float64
(
1
),
body
[
"user_id"
])
require
.
Equal
(
t
,
"user"
,
body
[
"role"
])
}
func
TestJWTAuth_ValidToken_LowercaseBearer
(
t
*
testing
.
T
)
{
user
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
Concurrency
:
5
,
TokenVersion
:
1
,
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
user
})
token
,
err
:=
authSvc
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestJWTAuth_MissingAuthorizationHeader
(
t
*
testing
.
T
)
{
router
,
_
:=
newJWTTestEnv
(
nil
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"UNAUTHORIZED"
,
body
.
Code
)
}
func
TestJWTAuth_InvalidHeaderFormat
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
header
string
}{
{
"无Bearer前缀"
,
"Token abc123"
},
{
"缺少空格分隔"
,
"Bearerabc123"
},
{
"仅有单词"
,
"abc123"
},
}
router
,
_
:=
newJWTTestEnv
(
nil
)
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
tt
.
header
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"INVALID_AUTH_HEADER"
,
body
.
Code
)
})
}
}
func
TestJWTAuth_EmptyToken
(
t
*
testing
.
T
)
{
router
,
_
:=
newJWTTestEnv
(
nil
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"EMPTY_TOKEN"
,
body
.
Code
)
}
func
TestJWTAuth_TamperedToken
(
t
*
testing
.
T
)
{
router
,
_
:=
newJWTTestEnv
(
nil
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"INVALID_TOKEN"
,
body
.
Code
)
}
func
TestJWTAuth_UserNotFound
(
t
*
testing
.
T
)
{
// 使用 user ID=1 的 token,但 repo 中没有该用户
fakeUser
:=
&
service
.
User
{
ID
:
999
,
Email
:
"ghost@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
TokenVersion
:
1
,
}
// 创建环境时不注入此用户,这样 GetByID 会失败
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{})
token
,
err
:=
authSvc
.
GenerateToken
(
fakeUser
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"USER_NOT_FOUND"
,
body
.
Code
)
}
func
TestJWTAuth_UserInactive
(
t
*
testing
.
T
)
{
user
:=
&
service
.
User
{
ID
:
1
,
Email
:
"disabled@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusDisabled
,
TokenVersion
:
1
,
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
user
})
token
,
err
:=
authSvc
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"USER_INACTIVE"
,
body
.
Code
)
}
func
TestJWTAuth_TokenVersionMismatch
(
t
*
testing
.
T
)
{
// Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改)
userForToken
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
TokenVersion
:
1
,
}
userInDB
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
TokenVersion
:
2
,
// 密码修改后版本递增
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
userInDB
})
token
,
err
:=
authSvc
.
GenerateToken
(
userForToken
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
var
body
ErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
body
))
require
.
Equal
(
t
,
"TOKEN_REVOKED"
,
body
.
Code
)
}
backend/internal/server/middleware/logger.go
View file @
6bccb8a8
package
middleware
package
middleware
import
(
import
(
"log"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
)
// Logger 请求日志中间件
// Logger 请求日志中间件
...
@@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc {
...
@@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc {
// 开始时间
// 开始时间
startTime
:=
time
.
Now
()
startTime
:=
time
.
Now
()
// 请求路径
path
:=
c
.
Request
.
URL
.
Path
// 处理请求
// 处理请求
c
.
Next
()
c
.
Next
()
// 结束时间
// 跳过健康检查等高频探针路径的日志
endTime
:=
time
.
Now
()
if
path
==
"/health"
||
path
==
"/setup/status"
{
return
}
// 执行时间
endTime
:=
time
.
Now
()
latency
:=
endTime
.
Sub
(
startTime
)
latency
:=
endTime
.
Sub
(
startTime
)
// 请求方法
method
:=
c
.
Request
.
Method
method
:=
c
.
Request
.
Method
// 请求路径
path
:=
c
.
Request
.
URL
.
Path
// 状态码
statusCode
:=
c
.
Writer
.
Status
()
statusCode
:=
c
.
Writer
.
Status
()
// 客户端IP
clientIP
:=
c
.
ClientIP
()
clientIP
:=
c
.
ClientIP
()
// 协议版本
protocol
:=
c
.
Request
.
Proto
protocol
:=
c
.
Request
.
Proto
accountID
,
hasAccountID
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
AccountID
)
.
(
int64
)
platform
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Platform
)
.
(
string
)
model
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Model
)
.
(
string
)
fields
:=
[]
zap
.
Field
{
zap
.
String
(
"component"
,
"http.access"
),
zap
.
Int
(
"status_code"
,
statusCode
),
zap
.
Int64
(
"latency_ms"
,
latency
.
Milliseconds
()),
zap
.
String
(
"client_ip"
,
clientIP
),
zap
.
String
(
"protocol"
,
protocol
),
zap
.
String
(
"method"
,
method
),
zap
.
String
(
"path"
,
path
),
}
if
hasAccountID
&&
accountID
>
0
{
fields
=
append
(
fields
,
zap
.
Int64
(
"account_id"
,
accountID
))
}
if
platform
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"platform"
,
platform
))
}
if
model
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"model"
,
model
))
}
l
:=
logger
.
FromContext
(
c
.
Request
.
Context
())
.
With
(
fields
...
)
l
.
Info
(
"http request completed"
,
zap
.
Time
(
"completed_at"
,
endTime
))
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log
.
Printf
(
"[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s"
,
endTime
.
Format
(
"2006/01/02 - 15:04:05"
),
statusCode
,
latency
,
clientIP
,
protocol
,
method
,
path
,
)
// 如果有错误,额外记录错误信息
if
len
(
c
.
Errors
)
>
0
{
if
len
(
c
.
Errors
)
>
0
{
l
og
.
Printf
(
"[GIN] E
rrors
: %v
"
,
c
.
Errors
.
String
())
l
.
Warn
(
"http request contains gin errors"
,
zap
.
String
(
"e
rrors"
,
c
.
Errors
.
String
())
)
}
}
}
}
}
}
backend/internal/server/middleware/misc_coverage_test.go
0 → 100644
View file @
6bccb8a8
//go:build unit
package
middleware
import
(
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestClientRequestID_GeneratesWhenMissing
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ClientRequestID
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
v
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ClientRequestID
)
require
.
NotNil
(
t
,
v
)
id
,
ok
:=
v
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
NotEmpty
(
t
,
id
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestClientRequestID_PreservesExisting
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ClientRequestID
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
id
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ClientRequestID
)
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"keep"
,
id
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
ClientRequestID
,
"keep"
))
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestRequestBodyLimit_LimitsBody
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
RequestBodyLimit
(
4
))
r
.
POST
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
_
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
require
.
Error
(
t
,
err
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/t"
,
bytes
.
NewBufferString
(
"12345"
))
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestForcePlatform_SetsContextAndGinValue
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ForcePlatform
(
"anthropic"
))
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
require
.
True
(
t
,
HasForcePlatform
(
c
))
v
,
ok
:=
GetForcePlatformFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"anthropic"
,
v
)
ctxV
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ForcePlatform
)
require
.
Equal
(
t
,
"anthropic"
,
ctxV
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAuthSubjectHelpers_RoundTrip
(
t
*
testing
.
T
)
{
c
:=
&
gin
.
Context
{}
c
.
Set
(
string
(
ContextKeyUser
),
AuthSubject
{
UserID
:
1
,
Concurrency
:
2
})
c
.
Set
(
string
(
ContextKeyUserRole
),
"admin"
)
sub
,
ok
:=
GetAuthSubjectFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
1
),
sub
.
UserID
)
require
.
Equal
(
t
,
2
,
sub
.
Concurrency
)
role
,
ok
:=
GetUserRoleFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"admin"
,
role
)
}
func
TestAPIKeyAndSubscriptionFromContext
(
t
*
testing
.
T
)
{
c
:=
&
gin
.
Context
{}
key
:=
&
service
.
APIKey
{
ID
:
1
}
c
.
Set
(
string
(
ContextKeyAPIKey
),
key
)
gotKey
,
ok
:=
GetAPIKeyFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
1
),
gotKey
.
ID
)
sub
:=
&
service
.
UserSubscription
{
ID
:
2
}
c
.
Set
(
string
(
ContextKeySubscription
),
sub
)
gotSub
,
ok
:=
GetSubscriptionFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
2
),
gotSub
.
ID
)
}
backend/internal/server/middleware/recovery_test.go
View file @
6bccb8a8
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
package
middleware
package
middleware
import
(
import
(
"bytes"
"encoding/json"
"encoding/json"
"net/http"
"net/http"
"net/http/httptest"
"net/http/httptest"
...
@@ -14,6 +15,34 @@ import (
...
@@ -14,6 +15,34 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
func
TestRecovery_PanicLogContainsInfo
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
// 临时替换 DefaultErrorWriter 以捕获日志输出
var
buf
bytes
.
Buffer
originalWriter
:=
gin
.
DefaultErrorWriter
gin
.
DefaultErrorWriter
=
&
buf
t
.
Cleanup
(
func
()
{
gin
.
DefaultErrorWriter
=
originalWriter
})
r
:=
gin
.
New
()
r
.
Use
(
Recovery
())
r
.
GET
(
"/panic"
,
func
(
c
*
gin
.
Context
)
{
panic
(
"custom panic message for test"
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/panic"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
w
.
Code
)
logOutput
:=
buf
.
String
()
require
.
Contains
(
t
,
logOutput
,
"custom panic message for test"
,
"日志应包含 panic 信息"
)
require
.
Contains
(
t
,
logOutput
,
"recovery_test.go"
,
"日志应包含堆栈跟踪文件名"
)
}
func
TestRecovery
(
t
*
testing
.
T
)
{
func
TestRecovery
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/request_access_logger_test.go
0 → 100644
View file @
6bccb8a8
package
middleware
import
(
"context"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
type
testLogSink
struct
{
mu
sync
.
Mutex
events
[]
*
logger
.
LogEvent
}
func
(
s
*
testLogSink
)
WriteLogEvent
(
event
*
logger
.
LogEvent
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
events
=
append
(
s
.
events
,
event
)
}
func
(
s
*
testLogSink
)
list
()
[]
*
logger
.
LogEvent
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
out
:=
make
([]
*
logger
.
LogEvent
,
len
(
s
.
events
))
copy
(
out
,
s
.
events
)
return
out
}
func
initMiddlewareTestLogger
(
t
*
testing
.
T
)
*
testLogSink
{
return
initMiddlewareTestLoggerWithLevel
(
t
,
"debug"
)
}
func
initMiddlewareTestLoggerWithLevel
(
t
*
testing
.
T
,
level
string
)
*
testLogSink
{
t
.
Helper
()
level
=
strings
.
TrimSpace
(
level
)
if
level
==
""
{
level
=
"debug"
}
if
err
:=
logger
.
Init
(
logger
.
InitOptions
{
Level
:
level
,
Format
:
"json"
,
ServiceName
:
"sub2api"
,
Environment
:
"test"
,
Output
:
logger
.
OutputOptions
{
ToStdout
:
false
,
ToFile
:
false
,
},
});
err
!=
nil
{
t
.
Fatalf
(
"init logger: %v"
,
err
)
}
sink
:=
&
testLogSink
{}
logger
.
SetSink
(
sink
)
t
.
Cleanup
(
func
()
{
logger
.
SetSink
(
nil
)
})
return
sink
}
func
TestRequestLogger_GenerateAndPropagateRequestID
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
RequestLogger
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
reqID
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
RequestID
)
.
(
string
)
if
!
ok
||
reqID
==
""
{
t
.
Fatalf
(
"request_id missing in context"
)
}
if
got
:=
c
.
Writer
.
Header
()
.
Get
(
requestIDHeader
);
got
!=
reqID
{
t
.
Fatalf
(
"response header request_id mismatch, header=%q ctx=%q"
,
got
,
reqID
)
}
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
if
w
.
Header
()
.
Get
(
requestIDHeader
)
==
""
{
t
.
Fatalf
(
"X-Request-ID should be set"
)
}
}
func
TestRequestLogger_KeepIncomingRequestID
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
RequestLogger
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
reqID
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
RequestID
)
.
(
string
)
if
reqID
!=
"rid-fixed"
{
t
.
Fatalf
(
"request_id=%q, want rid-fixed"
,
reqID
)
}
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
requestIDHeader
,
"rid-fixed"
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
if
got
:=
w
.
Header
()
.
Get
(
requestIDHeader
);
got
!=
"rid-fixed"
{
t
.
Fatalf
(
"header=%q, want rid-fixed"
,
got
)
}
}
func
TestLogger_AccessLogIncludesCoreFields
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
sink
:=
initMiddlewareTestLogger
(
t
)
r
:=
gin
.
New
()
r
.
Use
(
Logger
())
r
.
Use
(
func
(
c
*
gin
.
Context
)
{
ctx
:=
c
.
Request
.
Context
()
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
AccountID
,
int64
(
101
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Platform
,
"openai"
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Model
,
"gpt-5"
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Next
()
})
r
.
GET
(
"/api/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusCreated
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/test"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusCreated
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
events
:=
sink
.
list
()
if
len
(
events
)
==
0
{
t
.
Fatalf
(
"expected at least one log event"
)
}
found
:=
false
for
_
,
event
:=
range
events
{
if
event
==
nil
||
event
.
Message
!=
"http request completed"
{
continue
}
found
=
true
switch
v
:=
event
.
Fields
[
"status_code"
]
.
(
type
)
{
case
int
:
if
v
!=
http
.
StatusCreated
{
t
.
Fatalf
(
"status_code field mismatch: %v"
,
v
)
}
case
int64
:
if
v
!=
int64
(
http
.
StatusCreated
)
{
t
.
Fatalf
(
"status_code field mismatch: %v"
,
v
)
}
default
:
t
.
Fatalf
(
"status_code type mismatch: %T"
,
v
)
}
switch
v
:=
event
.
Fields
[
"account_id"
]
.
(
type
)
{
case
int64
:
if
v
!=
101
{
t
.
Fatalf
(
"account_id field mismatch: %v"
,
v
)
}
case
int
:
if
v
!=
101
{
t
.
Fatalf
(
"account_id field mismatch: %v"
,
v
)
}
default
:
t
.
Fatalf
(
"account_id type mismatch: %T"
,
v
)
}
if
event
.
Fields
[
"platform"
]
!=
"openai"
||
event
.
Fields
[
"model"
]
!=
"gpt-5"
{
t
.
Fatalf
(
"platform/model mismatch: %+v"
,
event
.
Fields
)
}
}
if
!
found
{
t
.
Fatalf
(
"access log event not found"
)
}
}
func
TestLogger_HealthPathSkipped
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
sink
:=
initMiddlewareTestLogger
(
t
)
r
:=
gin
.
New
()
r
.
Use
(
Logger
())
r
.
GET
(
"/health"
,
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/health"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
if
len
(
sink
.
list
())
!=
0
{
t
.
Fatalf
(
"health endpoint should not write access log"
)
}
}
func
TestLogger_AccessLogDroppedWhenLevelWarn
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
sink
:=
initMiddlewareTestLoggerWithLevel
(
t
,
"warn"
)
r
:=
gin
.
New
()
r
.
Use
(
RequestLogger
())
r
.
Use
(
Logger
())
r
.
GET
(
"/api/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusCreated
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/test"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
if
w
.
Code
!=
http
.
StatusCreated
{
t
.
Fatalf
(
"status=%d"
,
w
.
Code
)
}
events
:=
sink
.
list
()
for
_
,
event
:=
range
events
{
if
event
!=
nil
&&
event
.
Message
==
"http request completed"
{
t
.
Fatalf
(
"access log should not be indexed when level=warn: %+v"
,
event
)
}
}
}
backend/internal/server/middleware/request_logger.go
0 → 100644
View file @
6bccb8a8
package
middleware
import
(
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
const
requestIDHeader
=
"X-Request-ID"
// RequestLogger 在请求入口注入 request-scoped logger。
func
RequestLogger
()
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
if
c
.
Request
==
nil
{
c
.
Next
()
return
}
requestID
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
requestIDHeader
))
if
requestID
==
""
{
requestID
=
uuid
.
NewString
()
}
c
.
Header
(
requestIDHeader
,
requestID
)
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
RequestID
,
requestID
)
clientRequestID
,
_
:=
ctx
.
Value
(
ctxkey
.
ClientRequestID
)
.
(
string
)
requestLogger
:=
logger
.
With
(
zap
.
String
(
"component"
,
"http"
),
zap
.
String
(
"request_id"
,
requestID
),
zap
.
String
(
"client_request_id"
,
strings
.
TrimSpace
(
clientRequestID
)),
zap
.
String
(
"path"
,
c
.
Request
.
URL
.
Path
),
zap
.
String
(
"method"
,
c
.
Request
.
Method
),
)
ctx
=
logger
.
IntoContext
(
ctx
,
requestLogger
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Next
()
}
}
backend/internal/server/middleware/security_headers.go
View file @
6bccb8a8
...
@@ -3,6 +3,8 @@ package middleware
...
@@ -3,6 +3,8 @@ package middleware
import
(
import
(
"crypto/rand"
"crypto/rand"
"encoding/base64"
"encoding/base64"
"fmt"
"log"
"strings"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
...
@@ -18,11 +20,14 @@ const (
...
@@ -18,11 +20,14 @@ const (
CloudflareInsightsDomain
=
"https://static.cloudflareinsights.com"
CloudflareInsightsDomain
=
"https://static.cloudflareinsights.com"
)
)
// GenerateNonce generates a cryptographically secure random nonce
// GenerateNonce generates a cryptographically secure random nonce.
func
GenerateNonce
()
string
{
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
func
GenerateNonce
()
(
string
,
error
)
{
b
:=
make
([]
byte
,
16
)
b
:=
make
([]
byte
,
16
)
_
,
_
=
rand
.
Read
(
b
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
base64
.
StdEncoding
.
EncodeToString
(
b
)
return
""
,
fmt
.
Errorf
(
"generate CSP nonce: %w"
,
err
)
}
return
base64
.
StdEncoding
.
EncodeToString
(
b
),
nil
}
}
// GetNonceFromContext retrieves the CSP nonce from gin context
// GetNonceFromContext retrieves the CSP nonce from gin context
...
@@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
...
@@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
if
cfg
.
Enabled
{
if
cfg
.
Enabled
{
// Generate nonce for this request
// Generate nonce for this request
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
c
.
Set
(
CSPNonceKey
,
nonce
)
if
err
!=
nil
{
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
// Replace nonce placeholder in policy
log
.
Printf
(
"[SecurityHeaders] %v — 降级为无 nonce 的 CSP"
,
err
)
finalPolicy
:=
strings
.
ReplaceAll
(
policy
,
NonceTemplate
,
"'nonce-"
+
nonce
+
"'"
)
finalPolicy
:=
strings
.
ReplaceAll
(
policy
,
NonceTemplate
,
"'unsafe-inline'"
)
c
.
Header
(
"Content-Security-Policy"
,
finalPolicy
)
c
.
Header
(
"Content-Security-Policy"
,
finalPolicy
)
}
else
{
c
.
Set
(
CSPNonceKey
,
nonce
)
finalPolicy
:=
strings
.
ReplaceAll
(
policy
,
NonceTemplate
,
"'nonce-"
+
nonce
+
"'"
)
c
.
Header
(
"Content-Security-Policy"
,
finalPolicy
)
}
}
}
c
.
Next
()
c
.
Next
()
}
}
...
...
backend/internal/server/middleware/security_headers_test.go
View file @
6bccb8a8
...
@@ -19,7 +19,8 @@ func init() {
...
@@ -19,7 +19,8 @@ func init() {
func
TestGenerateNonce
(
t
*
testing
.
T
)
{
func
TestGenerateNonce
(
t
*
testing
.
T
)
{
t
.
Run
(
"generates_valid_base64_string"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"generates_valid_base64_string"
,
func
(
t
*
testing
.
T
)
{
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
require
.
NoError
(
t
,
err
)
// Should be valid base64
// Should be valid base64
decoded
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
nonce
)
decoded
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
nonce
)
...
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
...
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
t
.
Run
(
"generates_unique_nonces"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"generates_unique_nonces"
,
func
(
t
*
testing
.
T
)
{
nonces
:=
make
(
map
[
string
]
bool
)
nonces
:=
make
(
map
[
string
]
bool
)
for
i
:=
0
;
i
<
100
;
i
++
{
for
i
:=
0
;
i
<
100
;
i
++
{
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
require
.
NoError
(
t
,
err
)
assert
.
False
(
t
,
nonces
[
nonce
],
"nonce should be unique"
)
assert
.
False
(
t
,
nonces
[
nonce
],
"nonce should be unique"
)
nonces
[
nonce
]
=
true
nonces
[
nonce
]
=
true
}
}
})
})
t
.
Run
(
"nonce_has_expected_length"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"nonce_has_expected_length"
,
func
(
t
*
testing
.
T
)
{
nonce
:=
GenerateNonce
()
nonce
,
err
:=
GenerateNonce
()
require
.
NoError
(
t
,
err
)
// 16 bytes -> 24 chars in base64 (with padding)
// 16 bytes -> 24 chars in base64 (with padding)
assert
.
Len
(
t
,
nonce
,
24
)
assert
.
Len
(
t
,
nonce
,
24
)
})
})
...
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
...
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
// Benchmark tests
// Benchmark tests
func
BenchmarkGenerateNonce
(
b
*
testing
.
B
)
{
func
BenchmarkGenerateNonce
(
b
*
testing
.
B
)
{
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
GenerateNonce
()
_
,
_
=
GenerateNonce
()
}
}
}
}
...
...
backend/internal/server/router.go
View file @
6bccb8a8
...
@@ -29,6 +29,7 @@ func SetupRouter(
...
@@ -29,6 +29,7 @@ func SetupRouter(
redisClient
*
redis
.
Client
,
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
)
*
gin
.
Engine
{
// 应用中间件
// 应用中间件
r
.
Use
(
middleware2
.
RequestLogger
())
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
CORS
(
cfg
.
CORS
))
r
.
Use
(
middleware2
.
CORS
(
cfg
.
CORS
))
r
.
Use
(
middleware2
.
SecurityHeaders
(
cfg
.
Security
.
CSP
))
r
.
Use
(
middleware2
.
SecurityHeaders
(
cfg
.
Security
.
CSP
))
...
...
backend/internal/server/routes/admin.go
View file @
6bccb8a8
...
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
...
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth
// OpenAI OAuth
registerOpenAIOAuthRoutes
(
admin
,
h
)
registerOpenAIOAuthRoutes
(
admin
,
h
)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes
(
admin
,
h
)
// Gemini OAuth
// Gemini OAuth
registerGeminiOAuthRoutes
(
admin
,
h
)
registerGeminiOAuthRoutes
(
admin
,
h
)
...
@@ -101,6 +103,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -101,6 +103,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
{
runtime
.
GET
(
"/alert"
,
h
.
Admin
.
Ops
.
GetAlertRuntimeSettings
)
runtime
.
GET
(
"/alert"
,
h
.
Admin
.
Ops
.
GetAlertRuntimeSettings
)
runtime
.
PUT
(
"/alert"
,
h
.
Admin
.
Ops
.
UpdateAlertRuntimeSettings
)
runtime
.
PUT
(
"/alert"
,
h
.
Admin
.
Ops
.
UpdateAlertRuntimeSettings
)
runtime
.
GET
(
"/logging"
,
h
.
Admin
.
Ops
.
GetRuntimeLogConfig
)
runtime
.
PUT
(
"/logging"
,
h
.
Admin
.
Ops
.
UpdateRuntimeLogConfig
)
runtime
.
POST
(
"/logging/reset"
,
h
.
Admin
.
Ops
.
ResetRuntimeLogConfig
)
}
}
// Advanced settings (DB-backed)
// Advanced settings (DB-backed)
...
@@ -144,12 +149,18 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -144,12 +149,18 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Request drilldown (success + error)
// Request drilldown (success + error)
ops
.
GET
(
"/requests"
,
h
.
Admin
.
Ops
.
ListRequestDetails
)
ops
.
GET
(
"/requests"
,
h
.
Admin
.
Ops
.
ListRequestDetails
)
// Indexed system logs
ops
.
GET
(
"/system-logs"
,
h
.
Admin
.
Ops
.
ListSystemLogs
)
ops
.
POST
(
"/system-logs/cleanup"
,
h
.
Admin
.
Ops
.
CleanupSystemLogs
)
ops
.
GET
(
"/system-logs/health"
,
h
.
Admin
.
Ops
.
GetSystemLogIngestionHealth
)
// Dashboard (vNext - raw path for MVP)
// Dashboard (vNext - raw path for MVP)
ops
.
GET
(
"/dashboard/overview"
,
h
.
Admin
.
Ops
.
GetDashboardOverview
)
ops
.
GET
(
"/dashboard/overview"
,
h
.
Admin
.
Ops
.
GetDashboardOverview
)
ops
.
GET
(
"/dashboard/throughput-trend"
,
h
.
Admin
.
Ops
.
GetDashboardThroughputTrend
)
ops
.
GET
(
"/dashboard/throughput-trend"
,
h
.
Admin
.
Ops
.
GetDashboardThroughputTrend
)
ops
.
GET
(
"/dashboard/latency-histogram"
,
h
.
Admin
.
Ops
.
GetDashboardLatencyHistogram
)
ops
.
GET
(
"/dashboard/latency-histogram"
,
h
.
Admin
.
Ops
.
GetDashboardLatencyHistogram
)
ops
.
GET
(
"/dashboard/error-trend"
,
h
.
Admin
.
Ops
.
GetDashboardErrorTrend
)
ops
.
GET
(
"/dashboard/error-trend"
,
h
.
Admin
.
Ops
.
GetDashboardErrorTrend
)
ops
.
GET
(
"/dashboard/error-distribution"
,
h
.
Admin
.
Ops
.
GetDashboardErrorDistribution
)
ops
.
GET
(
"/dashboard/error-distribution"
,
h
.
Admin
.
Ops
.
GetDashboardErrorDistribution
)
ops
.
GET
(
"/dashboard/openai-token-stats"
,
h
.
Admin
.
Ops
.
GetDashboardOpenAITokenStats
)
}
}
}
}
...
@@ -267,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -267,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
}
}
func
registerSoraOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
sora
:=
admin
.
Group
(
"/sora"
)
{
sora
.
POST
(
"/generate-auth-url"
,
h
.
Admin
.
OpenAIOAuth
.
GenerateAuthURL
)
sora
.
POST
(
"/exchange-code"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeCode
)
sora
.
POST
(
"/refresh-token"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/st2at"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeSoraSessionToken
)
sora
.
POST
(
"/rt2at"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/accounts/:id/refresh"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshAccountToken
)
sora
.
POST
(
"/create-from-oauth"
,
h
.
Admin
.
OpenAIOAuth
.
CreateAccountFromOAuth
)
}
}
func
registerGeminiOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
func
registerGeminiOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
gemini
:=
admin
.
Group
(
"/gemini"
)
gemini
:=
admin
.
Group
(
"/gemini"
)
{
{
...
@@ -297,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -297,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies
.
PUT
(
"/:id"
,
h
.
Admin
.
Proxy
.
Update
)
proxies
.
PUT
(
"/:id"
,
h
.
Admin
.
Proxy
.
Update
)
proxies
.
DELETE
(
"/:id"
,
h
.
Admin
.
Proxy
.
Delete
)
proxies
.
DELETE
(
"/:id"
,
h
.
Admin
.
Proxy
.
Delete
)
proxies
.
POST
(
"/:id/test"
,
h
.
Admin
.
Proxy
.
Test
)
proxies
.
POST
(
"/:id/test"
,
h
.
Admin
.
Proxy
.
Test
)
proxies
.
POST
(
"/:id/quality-check"
,
h
.
Admin
.
Proxy
.
CheckQuality
)
proxies
.
GET
(
"/:id/stats"
,
h
.
Admin
.
Proxy
.
GetStats
)
proxies
.
GET
(
"/:id/stats"
,
h
.
Admin
.
Proxy
.
GetStats
)
proxies
.
GET
(
"/:id/accounts"
,
h
.
Admin
.
Proxy
.
GetProxyAccounts
)
proxies
.
GET
(
"/:id/accounts"
,
h
.
Admin
.
Proxy
.
GetProxyAccounts
)
proxies
.
POST
(
"/batch-delete"
,
h
.
Admin
.
Proxy
.
BatchDelete
)
proxies
.
POST
(
"/batch-delete"
,
h
.
Admin
.
Proxy
.
BatchDelete
)
...
...
backend/internal/server/routes/auth.go
View file @
6bccb8a8
...
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
...
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
// 公开接口
// 公开接口
auth
:=
v1
.
Group
(
"/auth"
)
auth
:=
v1
.
Group
(
"/auth"
)
{
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/register"
,
rateLimiter
.
LimitWithOptions
(
"auth-register"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
auth
.
POST
(
"/login/2fa"
,
h
.
Auth
.
Login2FA
)
FailureMode
:
middleware
.
RateLimitFailClose
,
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
}),
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
rateLimiter
.
LimitWithOptions
(
"auth-login"
,
20
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Login
)
auth
.
POST
(
"/login/2fa"
,
rateLimiter
.
LimitWithOptions
(
"auth-login-2fa"
,
20
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
Login2FA
)
auth
.
POST
(
"/send-verify-code"
,
rateLimiter
.
LimitWithOptions
(
"auth-send-verify-code"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
SendVerifyCode
)
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
auth
.
POST
(
"/refresh"
,
rateLimiter
.
LimitWithOptions
(
"refresh-token"
,
30
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
auth
.
POST
(
"/refresh"
,
rateLimiter
.
LimitWithOptions
(
"refresh-token"
,
30
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
FailureMode
:
middleware
.
RateLimitFailClose
,
...
...
backend/internal/server/routes/auth_rate_limit_integration_test.go
0 → 100644
View file @
6bccb8a8
//go:build integration
package
routes
import
(
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis
"github.com/testcontainers/testcontainers-go/modules/redis"
)
const
authRouteRedisImageTag
=
"redis:8.4-alpine"
func
TestAuthRegisterRateLimitThresholdHitReturns429
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
rdb
:=
startAuthRouteRedis
(
t
,
ctx
)
router
:=
newAuthRoutesTestRouter
(
rdb
)
const
path
=
"/api/v1/auth/register"
for
i
:=
1
;
i
<=
6
;
i
++
{
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
path
,
strings
.
NewReader
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
RemoteAddr
=
"198.51.100.10:23456"
w
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
w
,
req
)
if
i
<=
5
{
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
,
"第 %d 次请求应先进入业务校验"
,
i
)
continue
}
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
w
.
Code
,
"第 6 次请求应命中限流"
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"rate limit exceeded"
)
}
}
func
startAuthRouteRedis
(
t
*
testing
.
T
,
ctx
context
.
Context
)
*
redis
.
Client
{
t
.
Helper
()
ensureAuthRouteDockerAvailable
(
t
)
redisContainer
,
err
:=
tcredis
.
Run
(
ctx
,
authRouteRedisImageTag
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
redisContainer
.
Terminate
(
ctx
)
})
redisHost
,
err
:=
redisContainer
.
Host
(
ctx
)
require
.
NoError
(
t
,
err
)
redisPort
,
err
:=
redisContainer
.
MappedPort
(
ctx
,
"6379/tcp"
)
require
.
NoError
(
t
,
err
)
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
fmt
.
Sprintf
(
"%s:%d"
,
redisHost
,
redisPort
.
Int
()),
DB
:
0
,
})
require
.
NoError
(
t
,
rdb
.
Ping
(
ctx
)
.
Err
())
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
return
rdb
}
func
ensureAuthRouteDockerAvailable
(
t
*
testing
.
T
)
{
t
.
Helper
()
if
authRouteDockerAvailable
()
{
return
}
t
.
Skip
(
"Docker 未启用,跳过认证限流集成测试"
)
}
func
authRouteDockerAvailable
()
bool
{
if
os
.
Getenv
(
"DOCKER_HOST"
)
!=
""
{
return
true
}
socketCandidates
:=
[]
string
{
"/var/run/docker.sock"
,
filepath
.
Join
(
os
.
Getenv
(
"XDG_RUNTIME_DIR"
),
"docker.sock"
),
filepath
.
Join
(
authRouteUserHomeDir
(),
".docker"
,
"run"
,
"docker.sock"
),
filepath
.
Join
(
authRouteUserHomeDir
(),
".docker"
,
"desktop"
,
"docker.sock"
),
filepath
.
Join
(
"/run/user"
,
strconv
.
Itoa
(
os
.
Getuid
()),
"docker.sock"
),
}
for
_
,
socket
:=
range
socketCandidates
{
if
socket
==
""
{
continue
}
if
_
,
err
:=
os
.
Stat
(
socket
);
err
==
nil
{
return
true
}
}
return
false
}
func
authRouteUserHomeDir
()
string
{
home
,
err
:=
os
.
UserHomeDir
()
if
err
!=
nil
{
return
""
}
return
home
}
backend/internal/server/routes/auth_rate_limit_test.go
0 → 100644
View file @
6bccb8a8
package
routes
import
(
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func
newAuthRoutesTestRouter
(
redisClient
*
redis
.
Client
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
v1
:=
router
.
Group
(
"/api/v1"
)
RegisterAuthRoutes
(
v1
,
&
handler
.
Handlers
{
Auth
:
&
handler
.
AuthHandler
{},
Setting
:
&
handler
.
SettingHandler
{},
},
servermiddleware
.
JWTAuthMiddleware
(
func
(
c
*
gin
.
Context
)
{
c
.
Next
()
}),
redisClient
,
)
return
router
}
func
TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable
(
t
*
testing
.
T
)
{
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
"127.0.0.1:1"
,
DialTimeout
:
50
*
time
.
Millisecond
,
ReadTimeout
:
50
*
time
.
Millisecond
,
WriteTimeout
:
50
*
time
.
Millisecond
,
})
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
router
:=
newAuthRoutesTestRouter
(
rdb
)
paths
:=
[]
string
{
"/api/v1/auth/register"
,
"/api/v1/auth/login"
,
"/api/v1/auth/login/2fa"
,
"/api/v1/auth/send-verify-code"
,
}
for
_
,
path
:=
range
paths
{
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
path
,
strings
.
NewReader
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
RemoteAddr
=
"203.0.113.10:12345"
w
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
w
.
Code
,
"path=%s"
,
path
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"rate limit exceeded"
,
"path=%s"
,
path
)
}
}
backend/internal/server/routes/gateway.go
View file @
6bccb8a8
package
routes
package
routes
import
(
import
(
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
...
@@ -20,6 +22,11 @@ func RegisterGatewayRoutes(
...
@@ -20,6 +22,11 @@ func RegisterGatewayRoutes(
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
{
)
{
bodyLimit
:=
middleware
.
RequestBodyLimit
(
cfg
.
Gateway
.
MaxBodySize
)
bodyLimit
:=
middleware
.
RequestBodyLimit
(
cfg
.
Gateway
.
MaxBodySize
)
soraMaxBodySize
:=
cfg
.
Gateway
.
SoraMaxBodySize
if
soraMaxBodySize
<=
0
{
soraMaxBodySize
=
cfg
.
Gateway
.
MaxBodySize
}
soraBodyLimit
:=
middleware
.
RequestBodyLimit
(
soraMaxBodySize
)
clientRequestID
:=
middleware
.
ClientRequestID
()
clientRequestID
:=
middleware
.
ClientRequestID
()
opsErrorLogger
:=
handler
.
OpsErrorLoggerMiddleware
(
opsService
)
opsErrorLogger
:=
handler
.
OpsErrorLoggerMiddleware
(
opsService
)
...
@@ -36,6 +43,15 @@ func RegisterGatewayRoutes(
...
@@ -36,6 +43,15 @@ func RegisterGatewayRoutes(
gateway
.
GET
(
"/usage"
,
h
.
Gateway
.
Usage
)
gateway
.
GET
(
"/usage"
,
h
.
Gateway
.
Usage
)
// OpenAI Responses API
// OpenAI Responses API
gateway
.
POST
(
"/responses"
,
h
.
OpenAIGateway
.
Responses
)
gateway
.
POST
(
"/responses"
,
h
.
OpenAIGateway
.
Responses
)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway
.
POST
(
"/chat/completions"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"invalid_request_error"
,
"message"
:
"Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses."
,
},
})
})
}
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
...
@@ -82,4 +98,25 @@ func RegisterGatewayRoutes(
...
@@ -82,4 +98,25 @@ func RegisterGatewayRoutes(
antigravityV1Beta
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
antigravityV1Beta
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
antigravityV1Beta
.
POST
(
"/models/*modelAction"
,
h
.
Gateway
.
GeminiV1BetaModels
)
antigravityV1Beta
.
POST
(
"/models/*modelAction"
,
h
.
Gateway
.
GeminiV1BetaModels
)
}
}
// Sora 专用路由(强制使用 sora 平台)
soraV1
:=
r
.
Group
(
"/sora/v1"
)
soraV1
.
Use
(
soraBodyLimit
)
soraV1
.
Use
(
clientRequestID
)
soraV1
.
Use
(
opsErrorLogger
)
soraV1
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformSora
))
soraV1
.
Use
(
gin
.
HandlerFunc
(
apiKeyAuth
))
{
soraV1
.
POST
(
"/chat/completions"
,
h
.
SoraGateway
.
ChatCompletions
)
soraV1
.
GET
(
"/models"
,
h
.
Gateway
.
Models
)
}
// Sora 媒体代理(可选 API Key 验证)
if
cfg
.
Gateway
.
SoraMediaRequireAPIKey
{
r
.
GET
(
"/sora/media/*filepath"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
SoraGateway
.
MediaProxy
)
}
else
{
r
.
GET
(
"/sora/media/*filepath"
,
h
.
SoraGateway
.
MediaProxy
)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r
.
GET
(
"/sora/media-signed/*filepath"
,
h
.
SoraGateway
.
MediaProxySigned
)
}
}
Prev
1
…
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment