Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
195e227c
Commit
195e227c
authored
Jan 06, 2026
by
song
Browse files
merge: 合并 upstream/main 并保留本地图片计费功能
parents
6fa704d6
752882a0
Changes
187
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/http_upstream_test.go
View file @
195e227c
...
...
@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func
(
s
*
HTTPUpstreamSuite
)
SetupTest
()
{
s
.
cfg
=
&
config
.
Config
{}
s
.
cfg
=
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
AllowPrivateHosts
:
true
,
},
},
}
}
// newService 创建测试用的 httpUpstreamService 实例
...
...
backend/internal/repository/migrations_schema_integration_test.go
View file @
195e227c
...
...
@@ -26,6 +26,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn
(
t
,
tx
,
"users"
,
"notes"
,
"text"
,
0
,
false
)
// accounts: schedulable and rate-limit fields
requireColumn
(
t
,
tx
,
"accounts"
,
"notes"
,
"text"
,
0
,
true
)
requireColumn
(
t
,
tx
,
"accounts"
,
"schedulable"
,
"boolean"
,
0
,
false
)
requireColumn
(
t
,
tx
,
"accounts"
,
"rate_limited_at"
,
"timestamp with time zone"
,
0
,
true
)
requireColumn
(
t
,
tx
,
"accounts"
,
"rate_limit_reset_at"
,
"timestamp with time zone"
,
0
,
true
)
...
...
backend/internal/repository/pricing_service.go
View file @
195e227c
...
...
@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
...
...
@@ -16,9 +17,17 @@ type pricingRemoteClient struct {
httpClient
*
http
.
Client
}
func
NewPricingRemoteClient
()
service
.
PricingRemoteClient
{
func
NewPricingRemoteClient
(
cfg
*
config
.
Config
)
service
.
PricingRemoteClient
{
allowPrivate
:=
false
validateResolvedIP
:=
true
if
cfg
!=
nil
{
allowPrivate
=
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
validateResolvedIP
=
cfg
.
Security
.
URLAllowlist
.
Enabled
}
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
30
*
time
.
Second
,
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
validateResolvedIP
,
AllowPrivateHosts
:
allowPrivate
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
...
...
backend/internal/repository/pricing_service_test.go
View file @
195e227c
...
...
@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
...
...
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func
(
s
*
PricingServiceSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
client
,
ok
:=
NewPricingRemoteClient
()
.
(
*
pricingRemoteClient
)
client
,
ok
:=
NewPricingRemoteClient
(
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
AllowPrivateHosts
:
true
,
},
},
})
.
(
*
pricingRemoteClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
}
...
...
backend/internal/repository/proxy_probe_service.go
View file @
195e227c
...
...
@@ -5,28 +5,52 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
NewProxyExitInfoProber
()
service
.
ProxyExitInfoProber
{
return
&
proxyProbeService
{
ipInfoURL
:
defaultIPInfoURL
}
func
NewProxyExitInfoProber
(
cfg
*
config
.
Config
)
service
.
ProxyExitInfoProber
{
insecure
:=
false
allowPrivate
:=
false
validateResolvedIP
:=
true
if
cfg
!=
nil
{
insecure
=
cfg
.
Security
.
ProxyProbe
.
InsecureSkipVerify
allowPrivate
=
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
validateResolvedIP
=
cfg
.
Security
.
URLAllowlist
.
Enabled
}
if
insecure
{
log
.
Printf
(
"[ProxyProbe] Warning: TLS verification is disabled for proxy probing."
)
}
return
&
proxyProbeService
{
ipInfoURL
:
defaultIPInfoURL
,
insecureSkipVerify
:
insecure
,
allowPrivateHosts
:
allowPrivate
,
validateResolvedIP
:
validateResolvedIP
,
}
}
const
defaultIPInfoURL
=
"https://ipinfo.io/json"
type
proxyProbeService
struct
{
ipInfoURL
string
ipInfoURL
string
insecureSkipVerify
bool
allowPrivateHosts
bool
validateResolvedIP
bool
}
func
(
s
*
proxyProbeService
)
ProbeProxy
(
ctx
context
.
Context
,
proxyURL
string
)
(
*
service
.
ProxyExitInfo
,
int64
,
error
)
{
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
Timeout
:
15
*
time
.
Second
,
InsecureSkipVerify
:
true
,
InsecureSkipVerify
:
s
.
insecureSkipVerify
,
ProxyStrict
:
true
,
ValidateResolvedIP
:
s
.
validateResolvedIP
,
AllowPrivateHosts
:
s
.
allowPrivateHosts
,
})
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"failed to create proxy client: %w"
,
err
)
...
...
backend/internal/repository/proxy_probe_service_test.go
View file @
195e227c
...
...
@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func
(
s
*
ProxyProbeServiceSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
prober
=
&
proxyProbeService
{
ipInfoURL
:
"http://ipinfo.test/json"
}
s
.
prober
=
&
proxyProbeService
{
ipInfoURL
:
"http://ipinfo.test/json"
,
allowPrivateHosts
:
true
,
}
}
func
(
s
*
ProxyProbeServiceSuite
)
TearDownTest
()
{
...
...
backend/internal/repository/turnstile_service.go
View file @
195e227c
...
...
@@ -22,7 +22,8 @@ type turnstileVerifier struct {
func
NewTurnstileVerifier
()
service
.
TurnstileVerifier
{
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
10
*
time
.
Second
,
Timeout
:
10
*
time
.
Second
,
ValidateResolvedIP
:
true
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
10
*
time
.
Second
}
...
...
backend/internal/repository/user_repo.go
View file @
195e227c
...
...
@@ -329,17 +329,20 @@ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount flo
return
nil
}
// DeductBalance 扣除用户余额
// 透支策略:允许余额变为负数,确保当前请求能够完成
// 中间件会阻止余额 <= 0 的用户发起后续请求
func
(
r
*
userRepository
)
DeductBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
n
,
err
:=
client
.
User
.
Update
()
.
Where
(
dbuser
.
IDEQ
(
id
)
,
dbuser
.
BalanceGTE
(
amount
)
)
.
Where
(
dbuser
.
IDEQ
(
id
))
.
AddBalance
(
-
amount
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
err
}
if
n
==
0
{
return
service
.
Err
InsufficientBalance
return
service
.
Err
UserNotFound
}
return
nil
}
...
...
backend/internal/repository/user_repo_integration_test.go
View file @
195e227c
...
...
@@ -290,9 +290,14 @@ func (s *UserRepoSuite) TestDeductBalance() {
func
(
s
*
UserRepoSuite
)
TestDeductBalance_InsufficientFunds
()
{
user
:=
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"insuf@test.com"
,
Balance
:
5
})
// 透支策略:允许扣除超过余额的金额
err
:=
s
.
repo
.
DeductBalance
(
s
.
ctx
,
user
.
ID
,
999
)
s
.
Require
()
.
Error
(
err
,
"expected error for insufficient balance"
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrInsufficientBalance
)
s
.
Require
()
.
NoError
(
err
,
"DeductBalance should allow overdraft"
)
// 验证余额变为负数
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
user
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
InDelta
(
-
994.0
,
got
.
Balance
,
1e-6
,
"Balance should be negative after overdraft"
)
}
func
(
s
*
UserRepoSuite
)
TestDeductBalance_ExactAmount
()
{
...
...
@@ -306,6 +311,19 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
s
.
Require
()
.
InDelta
(
0.0
,
got
.
Balance
,
1e-6
)
}
func
(
s
*
UserRepoSuite
)
TestDeductBalance_AllowsOverdraft
()
{
user
:=
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"overdraft@test.com"
,
Balance
:
5.0
})
// 扣除超过余额的金额 - 应该成功
err
:=
s
.
repo
.
DeductBalance
(
s
.
ctx
,
user
.
ID
,
10.0
)
s
.
Require
()
.
NoError
(
err
,
"DeductBalance should allow overdraft"
)
// 验证余额为负
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
user
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
InDelta
(
-
5.0
,
got
.
Balance
,
1e-6
,
"Balance should be -5.0 after overdraft"
)
}
// --- Concurrency ---
func
(
s
*
UserRepoSuite
)
TestUpdateConcurrency
()
{
...
...
@@ -477,9 +495,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s
.
Require
()
.
NoError
(
err
,
"GetByID after DeductBalance"
)
s
.
Require
()
.
InDelta
(
7.5
,
got4
.
Balance
,
1e-6
)
// 透支策略:允许扣除超过余额的金额
err
=
s
.
repo
.
DeductBalance
(
s
.
ctx
,
user1
.
ID
,
999
)
s
.
Require
()
.
Error
(
err
,
"DeductBalance expected error for insufficient balance"
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrInsufficientBalance
,
"DeductBalance unexpected error"
)
s
.
Require
()
.
NoError
(
err
,
"DeductBalance should allow overdraft"
)
gotOverdraft
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
user1
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"GetByID after overdraft"
)
s
.
Require
()
.
Less
(
gotOverdraft
.
Balance
,
0.0
,
"Balance should be negative after overdraft"
)
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateConcurrency
(
s
.
ctx
,
user1
.
ID
,
3
),
"UpdateConcurrency"
)
got5
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
user1
.
ID
)
...
...
@@ -511,6 +532,6 @@ func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
func
(
s
*
UserRepoSuite
)
TestDeductBalance_NotFound
()
{
err
:=
s
.
repo
.
DeductBalance
(
s
.
ctx
,
999999
,
5
)
s
.
Require
()
.
Error
(
err
,
"expected error for non-existent user"
)
// DeductBalance 在用户不存在时返回 Err
InsufficientBalance 因为 WHERE 条件不匹配
s
.
Require
()
.
ErrorIs
(
err
,
service
.
Err
InsufficientBalance
)
// DeductBalance 在用户不存在时返回 Err
UserNotFound
s
.
Require
()
.
ErrorIs
(
err
,
service
.
Err
UserNotFound
)
}
backend/internal/server/api_contract_test.go
View file @
195e227c
...
...
@@ -296,13 +296,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
"smtp_password
": "secret"
,
"smtp_password
_configured": true
,
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key
": "secret-key"
,
"turnstile_secret_key
_configured": true
,
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
...
...
@@ -315,7 +315,9 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o"
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": ""
}
}`
,
},
...
...
backend/internal/server/http.go
View file @
195e227c
...
...
@@ -2,6 +2,7 @@
package
server
import
(
"log"
"net/http"
"time"
...
...
@@ -36,6 +37,15 @@ func ProvideRouter(
r
:=
gin
.
New
()
r
.
Use
(
middleware2
.
Recovery
())
if
len
(
cfg
.
Server
.
TrustedProxies
)
>
0
{
if
err
:=
r
.
SetTrustedProxies
(
cfg
.
Server
.
TrustedProxies
);
err
!=
nil
{
log
.
Printf
(
"Failed to set trusted proxies: %v"
,
err
)
}
}
else
{
if
err
:=
r
.
SetTrustedProxies
(
nil
);
err
!=
nil
{
log
.
Printf
(
"Failed to disable trusted proxies: %v"
,
err
)
}
}
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
}
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
195e227c
...
...
@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func
apiKeyAuthWithSubscription
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
queryKey
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
))
queryApiKey
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key"
))
if
queryKey
!=
""
||
queryApiKey
!=
""
{
AbortWithError
(
c
,
400
,
"api_key_in_query_deprecated"
,
"API key in query parameter is deprecated. Please use Authorization header instead."
)
return
}
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
var
apiKeyString
string
...
...
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKeyString
=
c
.
GetHeader
(
"x-goog-api-key"
)
}
// 如果header中没有,尝试从query参数中提取(Google API key风格)
if
apiKeyString
==
""
{
apiKeyString
=
c
.
Query
(
"key"
)
}
// 兼容常见别名
if
apiKeyString
==
""
{
apiKeyString
=
c
.
Query
(
"api_key"
)
}
// 如果所有header都没有API key
if
apiKeyString
==
""
{
AbortWithError
(
c
,
401
,
"API_KEY_REQUIRED"
,
"API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header
, or key/api_key query parameter
"
)
AbortWithError
(
c
,
401
,
"API_KEY_REQUIRED"
,
"API key is required in Authorization header (Bearer scheme), x-api-key header,
or
x-goog-api-key header"
)
return
}
...
...
backend/internal/server/middleware/api_key_auth_google.go
View file @
195e227c
...
...
@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key"
));
v
!=
""
{
abortWithGoogleError
(
c
,
400
,
"Query parameter api_key is deprecated. Use Authorization header or key instead."
)
return
}
apiKeyString
:=
extractAPIKeyFromRequest
(
c
)
if
apiKeyString
==
""
{
abortWithGoogleError
(
c
,
401
,
"API key is required"
)
...
...
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if
v
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-goog-api-key"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key"
));
v
!=
""
{
return
v
if
allowGoogleQueryKey
(
c
.
Request
.
URL
.
Path
)
{
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
));
v
!=
""
{
return
v
}
}
return
""
}
func
allowGoogleQueryKey
(
path
string
)
bool
{
return
strings
.
HasPrefix
(
path
,
"/v1beta"
)
||
strings
.
HasPrefix
(
path
,
"/antigravity/v1beta"
)
}
func
abortWithGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
{
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
195e227c
...
...
@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require
.
Equal
(
t
,
"UNAUTHENTICATED"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
return
nil
,
errors
.
New
(
"should not be called"
)
},
})
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
&
config
.
Config
{}))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test?api_key=legacy"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
var
resp
googleErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
resp
.
Error
.
Code
)
require
.
Equal
(
t
,
"Query parameter api_key is deprecated. Use Authorization header or key instead."
,
resp
.
Error
.
Message
)
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
return
&
service
.
APIKey
{
ID
:
1
,
Key
:
key
,
Status
:
service
.
StatusActive
,
User
:
&
service
.
User
{
ID
:
123
,
Status
:
service
.
StatusActive
,
},
},
nil
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test?key=valid"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_InvalidKey
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/cors.go
View file @
195e227c
package
middleware
import
(
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
var
corsWarningOnce
sync
.
Once
// CORS 跨域中间件
func
CORS
()
gin
.
HandlerFunc
{
func
CORS
(
cfg
config
.
CORSConfig
)
gin
.
HandlerFunc
{
allowedOrigins
:=
normalizeOrigins
(
cfg
.
AllowedOrigins
)
allowAll
:=
false
for
_
,
origin
:=
range
allowedOrigins
{
if
origin
==
"*"
{
allowAll
=
true
break
}
}
wildcardWithSpecific
:=
allowAll
&&
len
(
allowedOrigins
)
>
1
if
wildcardWithSpecific
{
allowedOrigins
=
[]
string
{
"*"
}
}
allowCredentials
:=
cfg
.
AllowCredentials
corsWarningOnce
.
Do
(
func
()
{
if
len
(
allowedOrigins
)
==
0
{
log
.
Println
(
"Warning: CORS allowed_origins not configured; cross-origin requests will be rejected."
)
}
if
wildcardWithSpecific
{
log
.
Println
(
"Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins."
)
}
if
allowAll
&&
allowCredentials
{
log
.
Println
(
"Warning: CORS allowed_origins set to '*', disabling allow_credentials."
)
}
})
if
allowAll
&&
allowCredentials
{
allowCredentials
=
false
}
allowedSet
:=
make
(
map
[
string
]
struct
{},
len
(
allowedOrigins
))
for
_
,
origin
:=
range
allowedOrigins
{
if
origin
==
""
||
origin
==
"*"
{
continue
}
allowedSet
[
origin
]
=
struct
{}{}
}
return
func
(
c
*
gin
.
Context
)
{
// 设置允许跨域的响应头
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Origin"
,
"*"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
origin
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Origin"
))
originAllowed
:=
allowAll
if
origin
!=
""
&&
!
allowAll
{
_
,
originAllowed
=
allowedSet
[
origin
]
}
if
originAllowed
{
if
allowAll
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Origin"
,
"*"
)
}
else
if
origin
!=
""
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Origin"
,
origin
)
c
.
Writer
.
Header
()
.
Add
(
"Vary"
,
"Origin"
)
}
if
allowCredentials
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
}
}
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
"Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
// 处理预检请求
if
c
.
Request
.
Method
==
"OPTIONS"
{
c
.
AbortWithStatus
(
204
)
if
c
.
Request
.
Method
==
http
.
MethodOptions
{
if
originAllowed
{
c
.
AbortWithStatus
(
http
.
StatusNoContent
)
}
else
{
c
.
AbortWithStatus
(
http
.
StatusForbidden
)
}
return
}
c
.
Next
()
}
}
func
normalizeOrigins
(
values
[]
string
)
[]
string
{
if
len
(
values
)
==
0
{
return
nil
}
normalized
:=
make
([]
string
,
0
,
len
(
values
))
for
_
,
value
:=
range
values
{
trimmed
:=
strings
.
TrimSpace
(
value
)
if
trimmed
==
""
{
continue
}
normalized
=
append
(
normalized
,
trimmed
)
}
return
normalized
}
backend/internal/server/middleware/security_headers.go
0 → 100644
View file @
195e227c
package
middleware
import
(
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// SecurityHeaders sets baseline security headers for all responses.
func
SecurityHeaders
(
cfg
config
.
CSPConfig
)
gin
.
HandlerFunc
{
policy
:=
strings
.
TrimSpace
(
cfg
.
Policy
)
if
policy
==
""
{
policy
=
config
.
DefaultCSPPolicy
}
return
func
(
c
*
gin
.
Context
)
{
c
.
Header
(
"X-Content-Type-Options"
,
"nosniff"
)
c
.
Header
(
"X-Frame-Options"
,
"DENY"
)
c
.
Header
(
"Referrer-Policy"
,
"strict-origin-when-cross-origin"
)
if
cfg
.
Enabled
{
c
.
Header
(
"Content-Security-Policy"
,
policy
)
}
c
.
Next
()
}
}
backend/internal/server/router.go
View file @
195e227c
...
...
@@ -24,7 +24,8 @@ func SetupRouter(
)
*
gin
.
Engine
{
// 应用中间件
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
CORS
())
r
.
Use
(
middleware2
.
CORS
(
cfg
.
CORS
))
r
.
Use
(
middleware2
.
SecurityHeaders
(
cfg
.
Security
.
CSP
))
// Serve embedded frontend if available
if
web
.
HasEmbeddedFrontend
()
{
...
...
backend/internal/service/account.go
View file @
195e227c
...
...
@@ -11,6 +11,7 @@ import (
type
Account
struct
{
ID
int64
Name
string
Notes
*
string
Platform
string
Type
string
Credentials
map
[
string
]
any
...
...
@@ -262,6 +263,17 @@ func parseTempUnschedStrings(value any) []string {
return
out
}
func
normalizeAccountNotes
(
value
*
string
)
*
string
{
if
value
==
nil
{
return
nil
}
trimmed
:=
strings
.
TrimSpace
(
*
value
)
if
trimmed
==
""
{
return
nil
}
return
&
trimmed
}
func
parseTempUnschedInt
(
value
any
)
int
{
switch
v
:=
value
.
(
type
)
{
case
int
:
...
...
backend/internal/service/account_service.go
View file @
195e227c
...
...
@@ -72,6 +72,7 @@ type AccountBulkUpdate struct {
// CreateAccountRequest 创建账号请求
type
CreateAccountRequest
struct
{
Name
string
`json:"name"`
Notes
*
string
`json:"notes"`
Platform
string
`json:"platform"`
Type
string
`json:"type"`
Credentials
map
[
string
]
any
`json:"credentials"`
...
...
@@ -85,6 +86,7 @@ type CreateAccountRequest struct {
// UpdateAccountRequest 更新账号请求
type
UpdateAccountRequest
struct
{
Name
*
string
`json:"name"`
Notes
*
string
`json:"notes"`
Credentials
*
map
[
string
]
any
`json:"credentials"`
Extra
*
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
...
...
@@ -123,6 +125,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
// 创建账号
account
:=
&
Account
{
Name
:
req
.
Name
,
Notes
:
normalizeAccountNotes
(
req
.
Notes
),
Platform
:
req
.
Platform
,
Type
:
req
.
Type
,
Credentials
:
req
.
Credentials
,
...
...
@@ -194,6 +197,9 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if
req
.
Name
!=
nil
{
account
.
Name
=
*
req
.
Name
}
if
req
.
Notes
!=
nil
{
account
.
Notes
=
normalizeAccountNotes
(
req
.
Notes
)
}
if
req
.
Credentials
!=
nil
{
account
.
Credentials
=
*
req
.
Credentials
...
...
backend/internal/service/account_test_service.go
View file @
195e227c
...
...
@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
...
...
@@ -14,9 +15,11 @@ import (
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
...
...
@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider
*
GeminiTokenProvider
antigravityGatewayService
*
AntigravityGatewayService
httpUpstream
HTTPUpstream
cfg
*
config
.
Config
}
// NewAccountTestService creates a new AccountTestService
...
...
@@ -53,15 +57,35 @@ func NewAccountTestService(
geminiTokenProvider
*
GeminiTokenProvider
,
antigravityGatewayService
*
AntigravityGatewayService
,
httpUpstream
HTTPUpstream
,
cfg
*
config
.
Config
,
)
*
AccountTestService
{
return
&
AccountTestService
{
accountRepo
:
accountRepo
,
geminiTokenProvider
:
geminiTokenProvider
,
antigravityGatewayService
:
antigravityGatewayService
,
httpUpstream
:
httpUpstream
,
cfg
:
cfg
,
}
}
func
(
s
*
AccountTestService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
==
nil
{
return
""
,
errors
.
New
(
"config is not available"
)
}
if
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
return
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
err
}
return
normalized
,
nil
}
// generateSessionString generates a Claude Code style session string
func
generateSessionString
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
...
...
@@ -183,11 +207,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return
s
.
sendErrorAndEnd
(
c
,
"No API key available"
)
}
api
URL
=
account
.
GetBaseURL
()
if
api
URL
==
""
{
api
URL
=
"https://api.anthropic.com"
base
URL
:
=
account
.
GetBaseURL
()
if
base
URL
==
""
{
base
URL
=
"https://api.anthropic.com"
}
apiURL
=
strings
.
TrimSuffix
(
apiURL
,
"/"
)
+
"/v1/messages"
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Invalid base URL: %s"
,
err
.
Error
()))
}
apiURL
=
strings
.
TrimSuffix
(
normalizedBaseURL
,
"/"
)
+
"/v1/messages"
}
else
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
}
...
...
@@ -300,7 +328,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if
baseURL
==
""
{
baseURL
=
"https://api.openai.com"
}
apiURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
+
"/responses"
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Invalid base URL: %s"
,
err
.
Error
()))
}
apiURL
=
strings
.
TrimSuffix
(
normalizedBaseURL
,
"/"
)
+
"/responses"
}
else
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
}
...
...
@@ -480,10 +512,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
// Use streamGenerateContent for real-time feedback
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:streamGenerateContent?alt=sse"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
modelID
)
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
modelID
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
fullURL
,
bytes
.
NewReader
(
payload
))
if
err
!=
nil
{
...
...
@@ -515,7 +551,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if
strings
.
TrimSpace
(
baseURL
)
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:streamGenerateContent?alt=sse"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
modelID
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:streamGenerateContent?alt=sse"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
),
modelID
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
payload
))
if
err
!=
nil
{
...
...
@@ -544,7 +584,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
wrappedBytes
,
_
:=
json
.
Marshal
(
wrapped
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:streamGenerateContent?alt=sse"
,
geminicli
.
GeminiCliBaseURL
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:streamGenerateContent?alt=sse"
,
normalizedBaseURL
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
fullURL
,
bytes
.
NewReader
(
wrappedBytes
))
if
err
!=
nil
{
...
...
Prev
1
2
3
4
5
6
7
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment